In [161]:
import pandas as pd
import random
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from nltk.tokenize import sent_tokenize
import nltk
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch
import copy
import time
import unicodedata
import re

# Sentence-splitting models needed by nltk's sent_tokenize.
nltk.download('punkt')

# Seed every RNG the notebook draws from: `random` (summary shuffling and
# random-word summaries) and torch's global generator (stochastic layers such
# as dropout while a model is in train mode).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Liora\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [183]:
# Corpus to evaluate; swap the commented line to use the alternative file.
filename = "CNN_DailyMail_555.json"
# filename = "DailyNews_300.json"
data = pd.read_json(f"../datasets/{filename}")
print(data.shape)

# Pretrained uncased BERT together with its masked-language-model head.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Prefer the GPU when torch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

(555, 4)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [184]:
class CustomDataset(Dataset):
    """Map-style ``Dataset`` backed by a pandas DataFrame.

    ``dataset[idx]`` returns the values of the first four columns of row
    ``idx`` as a tuple, so the wrapped frame needs at least four columns
    before it is indexed.

    ``get_sentences``, ``apply_preprocess`` and ``apply_random_words_summary``
    work on a copy of the wrapped DataFrame and return a new
    ``CustomDataset``. The receiver, and the DataFrame it was built from, are
    left unchanged, so a notebook cell chaining them can be re-run safely.
    """

    def __init__(self, dataset):
        # dataset: the pandas DataFrame being wrapped (stored as-is, not copied).
        self.dataset = dataset
        self.features = self.dataset.columns

    def __len__(self):
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        # Positional lookup. An out-of-range idx raises IndexError, which is
        # also what ends a plain ``for sample in dataset`` loop.
        return (self.dataset.iloc[idx, 0],
                self.dataset.iloc[idx, 1],
                self.dataset.iloc[idx, 2],
                self.dataset.iloc[idx, 3])

    def map(self, preprocessing_fn, **kwargs):
        """Apply ``preprocessing_fn(row, **kwargs)`` to every row and wrap the result."""
        return CustomDataset(self.dataset.apply(lambda x: preprocessing_fn(x, **kwargs), axis = 1))

    def select_columns(self, columns):
        """Return a dataset restricted to ``columns``, in the given order."""
        new_dataset = self.dataset[columns]
        return CustomDataset(new_dataset)

    def get_sentences(self):
        """Return a copy with a ``sentences`` column holding ``sent_tokenize(text)``."""
        new_dataset = self.dataset.copy()
        new_dataset['sentences'] = new_dataset['text'].apply(lambda x: sent_tokenize(x))
        return CustomDataset(new_dataset)

    # Data cleaning
    def preprocess_text(self, text: str) -> str:
        """Normalise one string for tokenisation.

        Steps:
        - lower-case;
        - newlines become spaces;
        - apostrophes are dropped ("don't" -> "dont");
        - accents are stripped (NFD + ASCII);
        - every non-word character is replaced by a space;
        - each run of digits becomes ``<NUM>``;
        - repeated spaces are collapsed and both ends stripped.

        All punctuation, sentence-ending dots included, is removed, so
        sentence splitting has to be done on the raw text beforehand.
        """
        text = text.lower()

        # before normalization: manual handling of contractions and line breaks
        text = text.replace('\n', ' ')
        text = text.replace(' \' ', '\'')
        text = text.replace('\'', '')

        # NFD splits accented letters into base letter + combining mark; the
        # ASCII encode drops the marks (and any other non-ASCII character).
        # Decode the bytes rather than slicing str(bytes): the bytes repr
        # writes control characters such as '\t' as a backslash plus a letter,
        # which then survived the next step as a stray word ('t', 'r', ...).
        text = unicodedata.normalize('NFD', text).encode('ascii', 'ignore').decode('ascii')

        # remove every non-word character (punctuation included)
        text = re.sub(r'[^\w]', ' ', text)

        # replace numbers by the <NUM> token
        text = re.sub(r'[0-9]+', '<NUM>', text)

        # collapse repeated spaces and strip both ends
        text = re.sub(r'( ){2,}', ' ', text).strip()

        return text

    def apply_preprocess(self):
        """Return a copy with ``summary``, ``text`` and every entry of ``sentences`` cleaned.

        Requires a ``sentences`` column (see ``get_sentences``).
        """
        new_dataset = self.dataset.copy()
        new_dataset["summary"] = new_dataset['summary'].apply(lambda x: self.preprocess_text(x))
        new_dataset["text"] = new_dataset['text'].apply(lambda x: self.preprocess_text(x))
        new_sentences_col = []
        for sentences in new_dataset['sentences']:
            new_sentences_col.append([self.preprocess_text(sentence) for sentence in sentences])
        new_dataset['sentences'] = new_sentences_col
        return CustomDataset(new_dataset)

    def random_words_summary(self, summary):
        """Draw ``len(words)`` words uniformly, with replacement, from ``summary``.

        Each drawn word is followed by one space, so a non-empty result ends
        with a trailing space; an empty summary yields "".
        """
        random_summary = ""
        summary = summary.split()
        for _ in range(len(summary)):
            random_summary += random.choice(summary) + ' '
        return random_summary

    def apply_random_words_summary(self):
        """Return a copy with a ``random_summary`` column (random-word baseline per row)."""
        new_dataset = self.dataset.copy()
        new_dataset['random_summary'] = new_dataset['summary'].apply(lambda x: self.random_words_summary(x))
        return CustomDataset(new_dataset)
    
# Build the working dataset. Sentences are split from the raw text first,
# because cleaning strips the punctuation sent_tokenize relies on. Then the
# text is cleaned and a random-word baseline summary is added.
dataset = (
    CustomDataset(data)
    .get_sentences()
    .apply_preprocess()
    .apply_random_words_summary()
)

In [185]:
def get_word_lengths(dataset, tokenizer, l_min = 4):
    """Flag which summary token ids belong to words long enough to be masked.

    Every summary (``sample[0]`` of each dataset item) is tokenised without
    special tokens and truncated to 512 ids. For each occurrence of a token:

    - a piece starting with ``##`` marks itself and the preceding piece True,
      since both are part of a multi-piece word;
    - any other token is marked ``len(token_text) >= l_min``.

    The dict is keyed by token id, so a later occurrence overwrites the flag
    left by an earlier one.
    NOTE(review): this makes eligibility order-dependent across summaries;
    consider OR-ing flags instead — confirm the intended semantics.

    Args:
        dataset: iterable of items whose first element is a summary string.
        tokenizer: tokenizer called as ``tokenizer(text, ...)["input_ids"]``
            that also provides ``convert_ids_to_tokens``.
        l_min: minimum character length for a non-``##`` token to be eligible.

    Returns:
        dict mapping token id -> bool, one entry per distinct token id seen.
    """
    word_lengths = {}
    seen_tokens = set()  # set, not list: membership checks stay O(1)

    for sample in dataset:
        summary = sample[0]
        token_ids = tokenizer(summary,
                              add_special_tokens = False,
                              truncation = True,
                              max_length = 512,
                              padding = False,
                              return_attention_mask = False)["input_ids"]
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        seen_tokens.update(token_ids)

        for i, (token_id, piece) in enumerate(zip(token_ids, pieces)):
            if piece.startswith('##'):
                # Continuation piece: the word spans several pieces, so the
                # piece before it belongs to the same word. Guard i == 0 so a
                # leading '##' never wraps around to token_ids[-1].
                if i > 0:
                    word_lengths[token_ids[i - 1]] = True
                word_lengths[token_id] = True
            else:
                word_lengths[token_id] = len(piece) >= l_min

    assert len(seen_tokens) == len(word_lengths), "Association of tokens with word length : FAILED."

    return word_lengths

In [186]:
def preprocessing_fn(x, tokenizer):
    """Add ``<column>_ids`` token-id fields to one dataset row.

    For each of ``summary``, ``text``, ``sentences`` and ``random_summary``,
    the row gets a ``<column>_ids`` entry holding the tokenizer's
    ``input_ids``. No special tokens are added, inputs are truncated to 512
    ids, and no padding is applied. ``sentences`` holds a list of strings, so
    its ids are a list of id lists.

    Args:
        x: a row (pandas Series or dict-like) with the four text columns.
        tokenizer: tokenizer called as ``tokenizer(text, ...)["input_ids"]``.

    Returns:
        The same row object, with the four ``*_ids`` fields set.
    """
    for column in ("summary", "text", "sentences", "random_summary"):
        x[f"{column}_ids"] = tokenizer(
            x[column],
            add_special_tokens = False,
            truncation = True,
            max_length = 512,
            padding = False,
            # Only input_ids are kept, so don't build attention masks.
            return_attention_mask = False,
        )["input_ids"]

    return x

splitted_dataset = dataset.select_columns(["summary", "text", "sentences", "random_summary"])

# Masking eligibility for every summary token id; computed on the raw
# strings, before they are replaced by ids.
word_lengths = get_word_lengths(splitted_dataset, tokenizer)

# Tokenize every text column of every row.
splitted_dataset = splitted_dataset.map(preprocessing_fn, tokenizer = tokenizer)

# Keep only the token-id columns.
splitted_dataset = splitted_dataset.select_columns(
    ["summary_ids", "text_ids", "sentences_ids", "random_summary_ids"]
)

# Preview the first row: summary, text, sentences, random_summary ids.
first_row = splitted_dataset[0]
for field in first_row:
    print(field)

[2079, 3501, 2089, 2145, 3288, 5571, 2114, 2577, 27946, 2079, 3501, 2071, 3288, 2976, 2942, 5571, 2114, 27946, 2942, 2916, 5160, 3421, 11851, 17789, 19953, 3008]
[1996, 2533, 1997, 3425, 2089, 2145, 3288, 2942, 5571, 2114, 11851, 17789, 3235, 13108, 2577, 27946, 4905, 2236, 4388, 9111, 3936, 2006, 9432, 9111, 2409, 12060, 2006, 9432, 2008, 2079, 3501, 2018, 2025, 5531, 2049, 4812, 2046, 1996, 2337, 1026, 16371, 2213, 1028, 1026, 16371, 2213, 1028, 5043, 1998, 1996, 3043, 2003, 7552, 2045, 2024, 3161, 4084, 2008, 2057, 2024, 2145, 1999, 1996, 2832, 1997, 2635, 9111, 2056, 2429, 2000, 1996, 2940, 2045, 2024, 9390, 2040, 2057, 2215, 2000, 3713, 2000, 2004, 1037, 2765, 1997, 2070, 3522, 8973, 17186, 2091, 2005, 2678, 4905, 2236, 4388, 9111, 2409, 12060, 7483, 2008, 1996, 3425, 7640, 4812, 2046, 11851, 17789, 3235, 13108, 2577, 27946, 2003, 7552, 2372, 1997, 1996, 2047, 2259, 2103, 2473, 4929, 7415, 2666, 28095, 2015, 1999, 3638, 1997, 11851, 17789, 3235, 15885, 2006, 1996, 13082, 2604, 200

In [187]:
def collate_fn(batch, max_length = 512):
    """Collate ``(summary_ids, text_ids, sentences_ids, random_summary_ids)`` rows into tensors.

    Each id list is right-padded with 0 up to ``max_length``. Longer lists are
    kept whole, and ``pad_sequence`` then pads the batch to its longest
    member. All outputs are int64, so they can go straight into an
    embedding layer.

    Args:
        batch: list of dataset items; ``random_summary_ids`` (index 3) is not collated.
        max_length: target padded length (the tokenizer truncates to 512).

    Returns:
        dict with
        - "text_ids": LongTensor (batch, L);
        - "summary_ids": LongTensor (batch, L);
        - "sentences_ids": list with one LongTensor (num_sentences, L) per item.
        An item without sentences gets an empty (0, max_length) tensor.
    """
    def pad_to_length(ids):
        # Explicit dtype on both halves: zeros() defaults to float32, which
        # used to promote the concatenated ids to float.
        seq = torch.as_tensor(ids, dtype = torch.long)
        filler = torch.zeros(max(0, max_length - seq.numel()), dtype = torch.long)
        return torch.cat([seq, filler])

    def pad_batch(sequences):
        return pad_sequence([pad_to_length(ids) for ids in sequences], batch_first = True, padding_value = 0)

    padded_text_ids = pad_batch([item[1] for item in batch])
    padded_summary_ids = pad_batch([item[0] for item in batch])
    padded_sentences_ids = [
        # pad_sequence rejects an empty list, so handle sentence-less items here.
        pad_batch(item[2]) if len(item[2]) > 0 else torch.zeros((0, max_length), dtype = torch.long)
        for item in batch
    ]

    return {"text_ids": padded_text_ids, "summary_ids": padded_summary_ids, "sentences_ids": padded_sentences_ids}

# Evaluation settings: rows per DataLoader batch, and optimisation steps used
# when fine-tuning on each summary.
batch_size, epochs = 32, 3
dataloader = DataLoader(splitted_dataset, collate_fn = collate_fn, batch_size = batch_size)
def training(summary, text, model, epochs = 10, lr = 2e-5, pad_token_id = 0):
    """Fine-tune a deep copy of ``model`` on one (summary, text) pair and return the copy.

    The caller's ``model`` is never modified. Both sequences form one batch
    of two and serve as their own labels. Positions equal to
    ``pad_token_id`` are excluded from attention and from the loss
    (label -100).
    NOTE(review): inputs are not masked, so this optimises a reconstruction
    loss rather than a masked-LM one — confirm this is intended.

    Args:
        summary: 1-D token-id tensor.
        text: 1-D token-id tensor with the same length as ``summary``.
        model: masked-LM module called as
            ``model(input_ids, attention_mask=..., labels=...)``, returning an
            object with ``.loss``.
        epochs: number of optimisation steps.
        lr: AdamW learning rate (same value as the notebook-level optimizer).
        pad_token_id: id used for padding (collate_fn pads with 0).

    Returns:
        The tuned copy, switched to eval mode for scoring.

    Raises:
        RuntimeError: if ``summary`` and ``text`` differ in length.
    """
    model_copy = copy.deepcopy(model)
    model_copy.train()
    # The optimizer must own the *copy's* parameters. One built over `model`
    # steps tensors that never receive gradients, leaving the copy untrained.
    optimizer = torch.optim.AdamW(model_copy.parameters(), lr = lr, weight_decay = 0.0)

    summary = summary.unsqueeze(0)
    text = text.unsqueeze(0)
    if summary.size(1) != text.size(1):
        raise RuntimeError("Sizes along the sequence length dimension must match.")

    # Inputs must live on the same device as the model's weights.
    device = next(model_copy.parameters()).device
    whole_input = torch.cat((summary, text), dim = 0).long().to(device)
    attention_mask = (whole_input != pad_token_id).long()
    labels = whole_input.masked_fill(whole_input == pad_token_id, -100)

    # All-padding input would give a NaN loss (no labelled positions).
    if bool(attention_mask.any()):
        for _ in range(epochs):
            outputs = model_copy(whole_input, attention_mask = attention_mask, labels = labels)
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    model_copy.eval()
    return model_copy

In [188]:
def mask_sentence(sentence, i, M, word_lengths, mask_token_id = None):
    """Mask every M-th eligible token of ``sentence``, starting at offset ``i``.

    Position ``j`` is replaced by the mask id when ``(j - i) % M == 0`` and
    ``word_lengths`` maps the token id at ``j`` to a truthy value. Ids absent
    from ``word_lengths`` are never masked.

    Args:
        sentence: 1-D tensor of token ids.
        i: offset of the first candidate position (callers loop i over range(M)).
        M: masking stride.
        word_lengths: dict token id -> eligibility flag.
        mask_token_id: id written at masked positions. Defaults to the global
            ``tokenizer.mask_token_id`` (103 for bert-base-uncased).

    Returns:
        (masked tensor built by ``torch.tensor``, list of masked positions in increasing order)
    """
    if mask_token_id is None:
        mask_token_id = tokenizer.mask_token_id
    # One tolist() instead of a per-element .item() call.
    token_ids = sentence.tolist()

    masked_token_ids = [
        j for j, token in enumerate(token_ids)
        if (j - i) % M == 0 and token in word_lengths and word_lengths[token]
    ]
    tokenized_sentence = list(token_ids)
    for j in masked_token_ids:
        tokenized_sentence[j] = mask_token_id

    return torch.tensor(tokenized_sentence), masked_token_ids

In [189]:
def modified_BLANC_help(sentences, model, model_tuned, p_mask = 0.15):
    """Score how much better ``model_tuned`` recovers masked document tokens than ``model``.

    ``M = int(1 / p_mask)``, which is 6 for the default 0.15. For every
    sentence and every offset ``i`` in ``range(M)``, the eligible tokens at
    positions ``i, i+M, ...`` are masked. Eligibility comes from the global
    ``word_lengths`` via ``mask_sentence``. Both models predict the masked
    positions, and each position is tallied in ``S[k][m]``, where ``k``
    (resp. ``m``) is 1 when the base (resp. tuned) model's arg-max equals
    the original id.

    Both models are scored in eval mode under ``torch.no_grad()``. Their
    previous train/eval modes are restored afterwards.
    Assumes both models sit on the same device (the tuned model is a deep
    copy of the base one in this notebook) — TODO confirm for other callers.

    Args:
        sentences: iterable of 1-D token-id tensors, right-padded with 0.
        model: base masked-LM model; ``model(ids).logits`` is (1, L, vocab).
        model_tuned: fine-tuned model with the same interface.
        p_mask: masking ratio used to derive the stride ``M``.

    Returns:
        ``(S[0][1] - S[1][0]) / total``, or 0.0 when no token was masked.
    """
    S = [[0, 0], [0, 0]]
    M = int(1 / p_mask)
    device = next(model.parameters()).device

    was_training = (model.training, model_tuned.training)
    model.eval()
    model_tuned.eval()
    try:
        with torch.no_grad():
            for sentence in sentences:
                # Drop the trailing 0-padding instead of feeding 512 tokens of
                # unmasked [PAD] to the model.
                nonpad = torch.nonzero(sentence != 0)
                if nonpad.numel() == 0:
                    continue
                sentence = sentence[: int(nonpad[-1]) + 1].long()

                for i in range(M):
                    masked_sentence, masked_tokens_ids = mask_sentence(sentence, i, M, word_lengths)
                    if not masked_tokens_ids:
                        continue  # nothing to score for this offset
                    inputs = masked_sentence.view(1, -1).long().to(device)

                    out_base = model(inputs).logits[0]
                    out_help = model_tuned(inputs).logits[0]

                    for j in masked_tokens_ids:
                        target = int(sentence[j])
                        k = int(int(torch.argmax(out_base[j])) == target)
                        m = int(int(torch.argmax(out_help[j])) == target)
                        S[k][m] += 1
    finally:
        model.train(was_training[0])
        model_tuned.train(was_training[1])

    total = S[0][0] + S[0][1] + S[1][0] + S[1][1]
    return (S[0][1] - S[1][0]) / total if total else 0.0

In [190]:
def blanc_tune(summary, text, model, p_mask = 0.15, N = 10, epochs = 10, mask_token_id = 103):
    """Build masked copies of ``summary`` and return a model fine-tuned on (summary, text).

    Masking set: ``N`` times, the positions of eligible summary tokens
    (global ``word_lengths``) are shuffled and cut into chunks of
    ``N_mask = max(1, int(N_summary * p_mask))``. Each chunk yields one copy
    of the summary with those positions set to ``mask_token_id``.
    NOTE(review): ``set_tune`` is built but not yet consumed — ``training``
    is called on the unmasked pair. Confirm whether tuning should use it.

    Args:
        summary: 1-D token-id tensor, right-padded with 0 (see collate_fn).
        text: 1-D token-id tensor of the source document.
        model: base model; it is deep-copied by ``training``, not modified.
        p_mask: fraction of summary tokens masked per copy.
        N: number of shuffled passes over the eligible positions.
        epochs: optimisation steps passed to ``training``.
        mask_token_id: id written at masked positions (103 = [MASK] in bert-base-uncased).

    Returns:
        The fine-tuned model copy returned by ``training``.
    """
    token_ids = summary.tolist()
    # Real summary length = position of the first pad; a summary filling the
    # whole window has no pad, where .index(0) used to raise ValueError.
    N_summary = token_ids.index(0) if 0 in token_ids else len(token_ids)
    # At least one token per chunk: int(N_summary * p_mask) is 0 for summaries
    # of <= 6 tokens, and a zero-size chunk made the loop below never end.
    N_mask = max(1, int(N_summary * p_mask))

    # Eligible positions do not depend on the pass, so compute them once.
    eligible_positions = [i for i, token in enumerate(token_ids) if token in word_lengths and word_lengths[token]]

    tune_rows = []
    for _ in range(N):
        positions = eligible_positions.copy()
        random.shuffle(positions)
        for start in range(0, len(positions), N_mask):
            # One fresh copy per chunk, recorded once the whole chunk is masked.
            masked_summary = token_ids.copy()
            for pos_to_mask in positions[start:start + N_mask]:
                masked_summary[pos_to_mask] = mask_token_id
            tune_rows.append({'summary': masked_summary, 'text': text})
    # Build the frame once instead of growing it row by row.
    set_tune = pd.DataFrame(tune_rows, columns = ['summary', 'text'])

    model_tuned = training(summary, text, model, epochs)
    return model_tuned

def blanc_tune_batch(batch, model, p_mask = 0.15, N = 10, epochs = 10, keep_models = True):
    """Fine-tune one model copy per summary in ``batch`` and score it with ``modified_BLANC_help``.

    Args:
        batch: dict from ``collate_fn`` with ``summary_ids``, ``text_ids`` and ``sentences_ids``.
        model: base model, passed to ``blanc_tune`` and used as the scoring baseline.
        p_mask: masking ratio, used both for tuning and for scoring.
        N: number of masking passes for ``blanc_tune``.
        epochs: optimisation steps for ``blanc_tune``.
        keep_models: when False, tuned copies are not retained. Each is a full
            copy of ``model``, so keeping a whole batch of them is
            memory-heavy; the returned model list is then empty.

    Returns:
        (list of tuned models, list of per-summary scores)
    """
    batch_tuned_models = []
    batch_accuracies = []

    items = zip(batch['summary_ids'], batch['text_ids'], batch['sentences_ids'])
    for i, (summary, text, sentences) in enumerate(items):
        print(f"Summary {i} of batch")
        start_time = time.perf_counter()  # monotonic clock for durations

        model_tuned = blanc_tune(summary, text, model, p_mask, N, epochs)
        if keep_models:
            batch_tuned_models.append(model_tuned)

        # Score with the same masking ratio used for tuning.
        accuracy = modified_BLANC_help(sentences, model, model_tuned, p_mask)
        print(accuracy)
        batch_accuracies.append(accuracy)

        elapsed_time = time.perf_counter() - start_time
        print(f"Elapsed Time: {elapsed_time} seconds")

    return batch_tuned_models, batch_accuracies

# Score every summary. Per-batch scores are appended rather than overwritten,
# so text_accuracies ends up with one score per dataset row. models_tuned
# still holds only the most recent batch's tuned copies; keeping every batch's
# copies would retain one full model copy per row.
text_accuracies = []
for batch in dataloader:
    models_tuned, batch_scores = blanc_tune_batch(batch, model, epochs = epochs)
    text_accuracies.extend(batch_scores)

Summary 0 of batch
