In [1]:
import pandas as pd
import torch
import nltk
from nltk.tokenize import sent_tokenize
from datasets import load_dataset
from tqdm.notebook import tqdm
import random

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


# The models the authors used:
try:
    from transformers import BertConfig, BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup
    from transformers import AlbertForMaskedLM, AlbertTokenizer
except ModuleNotFoundError:
    %pip install transformers
    from transformers import BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup
    from transformers import AlbertForMaskedLM, AlbertTokenizer

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Liora\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
DailyNews_ds = load_dataset('json', data_files='../datasets/DailyNews_300.json', split='train')
DailyNews_ds

Dataset({
    features: ['scores', 'text', 'summary', 'annotators_ids'],
    num_rows: 300
})

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True, seed = SEED)
model = BertForMaskedLM.from_pretrained('bert-base-uncased').to(DEVICE)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
def mask_sentence(sentence, i, M, L_min):
    return [tokenizer.mask_token
            if (j - i) % M == 0
            and (len(sentence[j]) >= L_min
                 or sentence[j].startswith('##')
                 or sentence[min(j+1, len(sentence)-1)].startswith('##'))
            else sentence[j]
            for j in range(len(sentence))]

In [6]:
def BLANC_help_modified(text, model, model_tuned, M = 6, L_min = 4, device = DEVICE):
    """
    Calculate BLANC similarity between summary and text using a specified model.

    Parameters:
    - summary (str): The summary text.
    - text (List[List[str]]): List of sentences represented as a list of words.
    - model: BERT model type
    - M (int): Parameter M for the algorithm (default is 6).
    - L_min (int): Minimum length requirement for masked words (default is 4).
    - sep (str): Separator between the inference help (filler/summary) and a sentence from the text (default is ' ').

    Returns:
    - float: BLANC similarity score.
    """

    S = [[0, 0], [0, 0]]

    for sentence in text:
        for i in range(M):
            masked_sentence = mask_sentence(sentence, i, M, L_min)

            tokenized_masked_sentence = torch.tensor(tokenizer.convert_tokens_to_ids(masked_sentence)).to(device) # Shape: [sequence_length]

            out_base = model(input_ids = tokenized_masked_sentence.unsqueeze(0)).logits  # Shape: [1, sequence_length, Bert_vocab_size]
            out_tune = model_tuned(input_ids = tokenized_masked_sentence.unsqueeze(0)).logits  # Shape: [1, sequence_length, Bert_vocab_size]

            out_base = out_base.squeeze(0)  # Shape: [sequence_length, Bert_vocab_size]
            out_tune = out_tune.squeeze(0)  # Shape: [sequence_length, Bert_vocab_size]

            masked_tokens = [idx for idx, word in enumerate(masked_sentence) if word == tokenizer.mask_token]

            for j in masked_tokens:
                predicted_idx_base = torch.argmax(out_base[j]).item()
                predicted_idx_tune = torch.argmax(out_tune[j]).item()

                predicted_word_base = tokenizer.convert_ids_to_tokens(predicted_idx_base)
                predicted_word_tune = tokenizer.convert_ids_to_tokens(predicted_idx_tune)

                # print(f'predicted_word_base: {predicted_word_base}')
                # print(f'predicted_word_help: {predicted_word_tune}')
                # print(f'sentence[{j}]: {sentence[j]}')

                k = int(predicted_word_base == sentence[j])
                m = int(predicted_word_tune == sentence[j])
                S[k][m] += 1

    # print(f'S: {S}')
    # try:
    B = (S[0][1] - S[1][0]) / (S[0][0] + S[1][1] + S[0][1] + S[1][0])
    # except ZeroDivisionError:
    #   B = 0.0

    return B



In [7]:
def get_word_lengths(dataset, tokenizer, l_min = 4):
    word_lengths = {}
    all_tokens = []

    for sample in dataset:
        summary = sample['summary']
        preprocessed_result = tokenizer(summary, 
                                        add_special_tokens = False,
                                        truncation = True,
                                        max_length = 512,
                                        padding = False,
                                        return_attention_mask = False)
        tokens = preprocessed_result["input_ids"]
        decoded_tokens = tokenizer.convert_ids_to_tokens(tokens)
        for token in tokens:
            if token not in all_tokens:
                all_tokens.append(token)

        i = 0
        while i < len(tokens):
            eligible = False
            if decoded_tokens[i].startswith('##'):
                eligible = True
                word_lengths[tokens[i - 1]] = eligible
                word_lengths[tokens[i]] = eligible
            else:
                if len(decoded_tokens[i]) >= l_min:
                    eligible = True
                word_lengths[tokens[i]] = eligible
            i += 1

    assert len(all_tokens) == len(word_lengths), "Association of tokens with word length : FAILED."

    return word_lengths

word_lengths = get_word_lengths(DailyNews_ds, tokenizer)
len(word_lengths)

5027

In [8]:
def training(set_tune, epochs = 10, device = DEVICE):
    model_tuned = BertForMaskedLM.from_pretrained('bert-base-uncased').to(DEVICE)
    optimizer = AdamW(model_tuned.parameters(), lr = 1e-4)
    model_tuned.train()

    inputs = torch.tensor(set_tune['masked_summaries'].tolist(), dtype = torch.long).to(device)
    label = torch.tensor(set_tune['summary'].tolist(), dtype = torch.long).to(device)
    # print(inputs.size())
    
    for epochs in range(epochs):
        outputs = model_tuned(input_ids = inputs, labels = label)
        loss = outputs.loss
        print(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    return model_tuned

In [9]:
def BLANC_tune(summary, text, model, p_mask = 0.15, N = 10, epochs = 10, device = DEVICE):
    N_summary = len(summary)
    # N_summary = len([word for word in summary if not word.startswith('##')])
    N_mask = int(N_summary*p_mask)
    set_tune = pd.DataFrame(columns = ['masked_summaries', 'summary'])

    tokenized_summary = tokenizer.convert_tokens_to_ids(summary)
    
    for _ in range(0, N):
        pos = [i for i, token in enumerate(tokenized_summary) if token in word_lengths and word_lengths[token]]
        random.shuffle(pos)
        while len(pos) != 0:
            masked_summary = tokenized_summary.copy()
            for pos_to_mask in pos[:N_mask]:
                masked_summary[pos_to_mask] = tokenizer.mask_token_id
            set_tune.loc[set_tune.shape[0]] = [masked_summary, tokenized_summary]
            pos = pos[N_mask:]

    # print(set_tune)
    model_tuned = None
    del model_tuned
    model_tuned = training(set_tune, epochs, device)    
    accuracy = BLANC_help_modified(text, model, model_tuned)

    return accuracy

In [10]:
summaries = DailyNews_ds['summary'] # (List[str])
texts = DailyNews_ds['text']  # (List[str]) each string is a paragraph made of a few sentences

In [11]:
# each text in texts is a list of sentences (each sentence is a string)
texts = [sent_tokenize(text.strip()) for text in texts] # List[List[str]]
assert len(texts) == len(summaries) == 300

In [12]:
tokenized_texts = [[tokenizer.tokenize(sentence) for sentence in text] for text in texts]
tokenized_summaries = [tokenizer.tokenize(summary) for summary in summaries]

In [13]:
import torch

def print_gpu_memory():
    print("GPU Memory Summary:")
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    print("\n")

# Call this function to print GPU memory usage
print_gpu_memory()


GPU Memory Summary:
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 428904 KiB | 428904 KiB | 428904 KiB |      0 B   |
|       from large pool | 428288 KiB | 428288 KiB | 428288 KiB |      0 B   |
|       from small pool |    616 KiB |    616 KiB |    616 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         | 428904 KiB | 428904 KiB | 428904 KiB |      0 B   |
|       from large pool | 428288 KiB | 428288 KiB | 428288 KiB |      0 B   |
|       from small pool |    616 KiB |    616 KiB |    616 KiB |      0 B   |
|-------------------------------------------

In [14]:
# torch.cuda.empty_cache()

scores = []
for summary, text in tqdm(zip(tokenized_summaries, tokenized_texts)):
    score = BLANC_tune(summary, text, model)
    print(score)
    scores.append(score)
scores

0it [00:00, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor(1.7869, device='cuda:0', grad_fn=<NllLossBackward0>)


KeyboardInterrupt: 

In [15]:
print_gpu_memory()

GPU Memory Summary:
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1835 MiB |   6655 MiB |  16218 MiB |  14382 MiB |
|       from large pool |   1833 MiB |   6653 MiB |  16213 MiB |  14380 MiB |
|       from small pool |      1 MiB |      2 MiB |      4 MiB |      2 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   1835 MiB |   6655 MiB |  16218 MiB |  14382 MiB |
|       from large pool |   1833 MiB |   6653 MiB |  16213 MiB |  14380 MiB |
|       from small pool |      1 MiB |      2 MiB |      4 MiB |      2 MiB |
|-------------------------------------------

In [16]:
torch.cuda.empty_cache()
print_gpu_memory()

GPU Memory Summary:
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1835 MiB |   6655 MiB |  16218 MiB |  14382 MiB |
|       from large pool |   1833 MiB |   6653 MiB |  16213 MiB |  14380 MiB |
|       from small pool |      1 MiB |      2 MiB |      4 MiB |      2 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   1835 MiB |   6655 MiB |  16218 MiB |  14382 MiB |
|       from large pool |   1833 MiB |   6653 MiB |  16213 MiB |  14380 MiB |
|       from small pool |      1 MiB |      2 MiB |      4 MiB |      2 MiB |
|-------------------------------------------