In [4]:
import pandas as pd
import random
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch
import copy
import time

random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
data = pd.read_json('../datasets/DailyNews_300.json')
print(data.shape)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

(300, 4)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.features = self.dataset.columns

    def __len__(self):
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        return (self.dataset.iloc[idx, 0], self.dataset.iloc[idx, 1], self.dataset.iloc[idx, 2])
    
    def map(self, preprocessing_fn, **kwargs):
        return CustomDataset(self.dataset.apply(lambda x: preprocessing_fn(x, **kwargs), axis = 1))
    
    def select_columns(self, columns):
        new_dataset = self.dataset[columns] 
        return CustomDataset(new_dataset)
    
    def get_sentences(self):
        self.dataset['sentences'] = self.dataset['text'].apply(lambda x: x.split('.'))
        return CustomDataset(self.dataset)
    
dataset = CustomDataset(data)
dataset = dataset.get_sentences()
# print(dataset.__getitem__(0))

In [7]:
def get_word_lengths(dataset, tokenizer):
    word_lengths = {}
    all_tokens = []

    for sample in dataset:
        summary = sample[0]
        preprocessed_result = tokenizer(summary, 
                                        add_special_tokens = False,
                                        truncation = True,
                                        max_length = 512,
                                        padding = False,
                                        return_attention_mask = False)
        tokens = preprocessed_result["input_ids"]
        decoded_tokens = tokenizer.convert_ids_to_tokens(tokens)
        for token in tokens:
            if token not in all_tokens:
                all_tokens.append(token)

        i = 0
        while i < len(tokens):
            if decoded_tokens[i].startswith('##'):
                combined_word = decoded_tokens[i - 1] + decoded_tokens[i][2:]
                word_lengths[tokens[i - 1]] = len(combined_word)
                word_lengths[tokens[i]] = len(combined_word)
            else:
                word_lengths[tokens[i]] = len(decoded_tokens[i])
            i += 1

    assert len(all_tokens) == len(word_lengths), "Association of tokens with word length : FAILED."

    return word_lengths

In [8]:
def preprocessing_fn(x, tokenizer):
    x["summary_ids"] = tokenizer(
        x["summary"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    x["text_ids"] = tokenizer(
        x["text"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    x["sentences_ids"] = tokenizer(
        x["sentences"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    return x

splitted_dataset = dataset.select_columns(["summary", "text", "sentences"])
# print(splitted_dataset.__getitem__(0))

word_lengths = get_word_lengths(splitted_dataset, tokenizer)

# Tokenize the dataset
splitted_dataset = splitted_dataset.map(
    preprocessing_fn, tokenizer = tokenizer
)
print(splitted_dataset.__getitem__(0)[2])

# Remove useless columns
splitted_dataset = splitted_dataset.select_columns(["summary_ids", "text_ids", "sentences_ids"])
print(splitted_dataset.__getitem__(0)[2])

['Mario Mandzukic pounces to fire the ball past Jordan Pickford and put Croatia into the World Cup final', ' Photo: Reuters Independent', "ie \n \nFormer England defender Gary Neville suggested Gareth Southgate's squad had done more than could have been expected of them at this World Cup as they bowed out with a semi-final defeat against Croatia", ' \n \nhttps://www', 'independent', 'ie/sport/soccer/world-cup-2018/gary-neville-salutes-englands-overachievers-as-alan-shearer-and-rio-ferdinand-give-their-verdicts-37108667', 'html \n \nhttps://www', 'independent', 'ie/incoming/article37108634', 'ece/7571a/AUTOCROP/h342/52Man1', "jpg \n   Email     \nFormer England defender Gary Neville suggested Gareth Southgate's squad had done more than could have been expected of them at this World Cup as they bowed out with a semi-final defeat against Croatia", ' \n  \nA jaded England faded after had Kieran Trippier fired England ahead after just five minutes with a superb free-kick, with goals from Iv

In [9]:
def collate_fn(batch):
    # Get the tokenized sequences for each item in the batch
    text_ids_batch = [torch.tensor(item[1], dtype = torch.int) for item in batch]
    summary_ids_batch = [torch.tensor(item[0], dtype = torch.int) for item in batch]
    sentences_ids_batch = [
        [torch.tensor(sentence, dtype = torch.int) for sentence in item[2]]
        for item in batch
    ]

    # Pad sequences to the maximum length in the batch
    padded_text_ids = pad_sequence([torch.cat([item, torch.zeros(max(0, 512 - len(item)))]) for item in text_ids_batch], batch_first = True, padding_value = 0)
    padded_summary_ids = pad_sequence([torch.cat([item, torch.zeros(max(0, 512 - len(item)))]) for item in summary_ids_batch], batch_first = True, padding_value = 0)
    padded_sentences_ids = [
        pad_sequence(
            [torch.cat([sentence, torch.zeros(max(0, 512 - len(sentence)), dtype = torch.int)]) for sentence in item],
            batch_first = True,
            padding_value = 0
        )
        for item in sentences_ids_batch
    ]

    return {"text_ids": padded_text_ids, "summary_ids": padded_summary_ids, "sentences_ids": padded_sentences_ids}

batch_size = 32
dataloader = DataLoader(splitted_dataset, batch_size = batch_size, collate_fn = collate_fn)

epochs = 3
def training(summary, text, model, epochs = 10):
    model_copy = copy.deepcopy(model)
    model_copy.train()

    summary = summary.unsqueeze(0)
    text = text.unsqueeze(0)
    if summary.size(1) != text.size(1):
        raise RuntimeError("Sizes along the sequence length dimension must match.")
    
    for epochs in range(epochs):
        whole_input = torch.cat((summary, text), dim = 0).long()
        outputs = model_copy(whole_input, labels = whole_input)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return model_copy

In [10]:
def mask_sentence(sentence, i, M, word_lengths, l_min):
    tokenized_sentence = []
    masked_token_ids = []

    for j in range(len(sentence)):
        if (j - i) % M == 0 and sentence[j].item() in word_lengths and word_lengths[sentence[j].item()] >= l_min:
            tokenized_sentence.append(tokenizer.mask_token_id)
            masked_token_ids.append(j)
        else:
            tokenized_sentence.append(sentence[j].item())
            
    tokenized_sentence= torch.tensor(tokenized_sentence)
    
    return tokenized_sentence, masked_token_ids

In [11]:
def modified_BLANC_help(sentences, model, model_tuned, p_mask = 0.15, l_min = 4):
    S = [[0, 0], [0, 0]]
    M = int(1/p_mask)
    
    for sentence in sentences:
        for i in range(M):
            masked_sentence, masked_tokens_ids = mask_sentence(sentence, i, M, word_lengths, l_min)
            masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()
            sentence = sentence.long()

            out_base = model(masked_sentence, labels = sentence).logits
            out_help = model_tuned(masked_sentence, labels = sentence).logits

            for j in masked_tokens_ids:
                predicted_token_model = torch.argmax(out_base[0][j])
                predicted_token_model_tuned = torch.argmax(out_help[0][j])
                
                k = int(predicted_token_model == sentence[j])
                m = int(predicted_token_model_tuned == sentence[j])
                S[k][m] += 1
        break
    try:
      B = (S[0][1] - S[1][0]) / (S[0][0] + S[1][1] + S[0][1] + S[1][0])
    except ZeroDivisionError:
      B = 0.0
    
    return B

In [15]:
def blanc_tune(summary, text, model, p_mask = 0.15, l_min = 4, N = 10, epochs = 10):
    N_summary = len(summary[:summary.tolist().index(0)])
    N_mask = int(N_summary*p_mask)
    set_tune = pd.DataFrame(columns = ['summary', 'text'])

    for j in range(0, N):
        pos = [i for i, token in enumerate(summary.tolist()) if token in word_lengths and word_lengths[token] >= l_min]
        random.shuffle(pos)
        # print(len(pos), pos)
        # print(N_mask)
        while len(pos) != 0:
            masked_summary = summary.tolist().copy()
            for pos_to_mask in pos[:N_mask]:
                masked_summary[pos_to_mask] = '[MASK]'
                set_tune.loc[set_tune.shape[0]] = [masked_summary, text]
            pos = pos[N_mask:]

    model_tuned = training(summary, text, model, epochs)
    print('\n')      
    return model_tuned

def blanc_tune_batch(batch, model, p_mask = 0.15, l_min = 4, N = 10, epochs = 10):
    batch_tuned_models = []
    batch_accuracies = []
    
    i = 0
    for summary, text, sentences in zip(batch['summary_ids'], batch['text_ids'], batch['sentences_ids']):
        print(f"Summary {i} of batch")
        i += 1
        start_time = time.time()
        model_tuned = blanc_tune(summary, text, model, p_mask, l_min, N, epochs)
        batch_tuned_models.append(model_tuned)
        accuracy = modified_BLANC_help(sentences, model, model_tuned)
        print(accuracy)
        end_time = time.time()
        batch_accuracies.append(accuracy)
        elapsed_time = end_time - start_time
        print(f"Elapsed Time: {elapsed_time} seconds")

    return batch_tuned_models, batch_accuracies

for batch in dataloader:
    models_tuned, text_accuracies = blanc_tune_batch(batch, model, epochs = epochs)

Summary 0 of batch
67 [18, 87, 42, 37, 7, 102, 78, 111, 19, 88, 11, 51, 43, 27, 23, 83, 34, 66, 80, 109, 24, 32, 14, 68, 103, 93, 116, 71, 35, 106, 4, 59, 29, 95, 13, 0, 21, 97, 41, 86, 38, 99, 69, 15, 114, 92, 62, 49, 110, 16, 98, 1, 2, 39, 60, 5, 54, 107, 74, 6, 84, 8, 118, 44, 56, 3, 17]
18
67 [1, 19, 29, 38, 0, 111, 44, 43, 59, 5, 39, 99, 49, 11, 14, 41, 7, 24, 8, 88, 66, 98, 118, 80, 51, 93, 109, 102, 23, 6, 83, 92, 37, 60, 2, 87, 3, 27, 17, 114, 116, 97, 35, 42, 84, 18, 54, 106, 110, 62, 4, 69, 74, 78, 21, 71, 16, 32, 107, 13, 34, 95, 68, 103, 56, 86, 15]
18
67 [118, 32, 4, 102, 17, 41, 16, 2, 114, 78, 37, 106, 95, 92, 34, 1, 80, 44, 43, 3, 38, 18, 71, 88, 68, 66, 84, 111, 5, 110, 60, 11, 15, 59, 24, 6, 109, 14, 49, 103, 99, 107, 13, 39, 23, 93, 29, 69, 21, 86, 54, 116, 7, 83, 74, 0, 98, 56, 19, 51, 42, 62, 35, 87, 8, 97, 27]
18
67 [24, 110, 42, 11, 8, 62, 5, 83, 15, 88, 13, 78, 43, 99, 35, 19, 38, 14, 32, 27, 109, 6, 86, 116, 98, 54, 118, 23, 107, 59, 103, 16, 44, 21, 97, 49, 11

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.






  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 18.041014194488525 seconds
Summary 1 of batch
53 [64, 85, 61, 18, 35, 59, 66, 13, 25, 81, 9, 27, 83, 10, 5, 73, 50, 56, 63, 8, 67, 7, 20, 32, 4, 2, 60, 40, 29, 19, 57, 72, 12, 49, 14, 37, 33, 80, 3, 69, 39, 28, 71, 52, 23, 62, 76, 58, 38, 36, 79, 30, 34]
13
53 [37, 61, 57, 76, 32, 67, 64, 59, 80, 83, 39, 52, 18, 58, 69, 9, 28, 71, 33, 40, 2, 29, 38, 79, 8, 36, 5, 85, 56, 62, 30, 14, 35, 27, 49, 20, 60, 10, 50, 12, 72, 13, 63, 73, 19, 25, 3, 4, 81, 7, 66, 23, 34]
13
53 [3, 58, 60, 27, 69, 80, 30, 62, 81, 32, 5, 76, 36, 8, 20, 25, 73, 59, 19, 40, 34, 37, 23, 71, 4, 10, 14, 12, 29, 79, 72, 61, 63, 50, 28, 57, 18, 85, 13, 49, 64, 83, 39, 9, 52, 38, 2, 33, 66, 56, 7, 35, 67]
13
53 [19, 62, 56, 76, 4, 64, 2, 69, 33, 60, 7, 32, 80, 36, 27, 63, 58, 39, 9, 5, 10, 49, 13, 12, 61, 67, 18, 14, 28, 57, 20, 34, 72, 35, 30, 81, 37, 23, 29, 85, 3, 40, 79, 66, 50, 71, 83, 59, 52, 38, 73, 25, 8]
13
53 [76, 62, 20, 8, 34, 10, 35, 66, 23, 14, 27, 63, 57, 85, 37, 56, 13, 32, 7, 49, 61, 36

  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.12
Elapsed Time: 17.784279346466064 seconds
Summary 2 of batch
51 [51, 1, 26, 37, 39, 32, 42, 30, 49, 62, 4, 69, 72, 34, 28, 7, 3, 64, 15, 6, 35, 47, 8, 31, 17, 29, 60, 46, 74, 66, 2, 77, 27, 55, 0, 61, 57, 22, 41, 20, 9, 56, 44, 21, 63, 71, 33, 59, 5, 13, 10]
11
51 [41, 34, 66, 33, 44, 30, 0, 56, 35, 3, 51, 55, 7, 1, 63, 42, 9, 21, 29, 37, 47, 69, 61, 13, 26, 4, 62, 46, 17, 49, 10, 77, 5, 64, 32, 71, 39, 28, 2, 6, 8, 22, 59, 15, 60, 72, 31, 57, 20, 74, 27]
11
51 [44, 57, 34, 27, 30, 0, 5, 66, 2, 69, 62, 8, 46, 61, 10, 13, 49, 15, 37, 26, 7, 31, 39, 21, 9, 35, 64, 47, 3, 17, 55, 28, 77, 4, 59, 29, 51, 1, 42, 20, 6, 74, 60, 33, 56, 22, 72, 71, 63, 32, 41]
11
51 [47, 46, 59, 30, 20, 42, 28, 13, 8, 22, 60, 66, 0, 9, 55, 35, 69, 37, 29, 33, 62, 56, 2, 63, 27, 61, 4, 51, 10, 57, 44, 1, 72, 5, 41, 7, 15, 26, 34, 3, 31, 49, 77, 6, 74, 21, 71, 39, 17, 32, 64]
11
51 [61, 27, 60, 37, 57, 62, 26, 10, 2, 5, 33, 41, 77, 66, 59, 63, 17, 15, 74, 49, 0, 71, 44, 9, 64, 30, 34, 51, 69, 28, 72, 20, 55

  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.08333333333333333
Elapsed Time: 17.7088725566864 seconds
Summary 3 of batch
27 [34, 7, 32, 23, 18, 54, 2, 46, 40, 12, 6, 26, 53, 38, 22, 55, 33, 31, 35, 42, 51, 44, 5, 25, 28, 57, 3]
8
27 [32, 51, 53, 6, 34, 5, 23, 26, 2, 25, 38, 31, 18, 7, 55, 40, 28, 54, 42, 57, 46, 33, 22, 3, 12, 44, 35]
8
27 [6, 35, 22, 2, 32, 33, 51, 55, 12, 18, 23, 34, 3, 44, 38, 54, 31, 25, 28, 57, 40, 5, 26, 7, 42, 46, 53]
8
27 [57, 6, 38, 53, 31, 33, 46, 54, 51, 42, 23, 32, 40, 3, 5, 7, 22, 44, 34, 26, 18, 12, 25, 55, 2, 28, 35]
8
27 [2, 40, 46, 18, 53, 26, 5, 34, 23, 6, 35, 38, 31, 57, 25, 44, 22, 28, 32, 3, 33, 54, 42, 51, 55, 12, 7]
8
27 [53, 40, 44, 6, 5, 55, 2, 38, 57, 28, 34, 46, 42, 22, 26, 3, 23, 7, 18, 12, 32, 51, 35, 31, 25, 54, 33]
8
27 [26, 7, 25, 53, 54, 23, 42, 38, 34, 44, 51, 12, 40, 18, 6, 46, 35, 31, 3, 55, 2, 28, 33, 22, 5, 32, 57]
8
27 [53, 44, 35, 12, 22, 31, 26, 3, 18, 28, 57, 46, 38, 32, 6, 23, 40, 5, 2, 42, 33, 55, 54, 7, 51, 34, 25]
8
27 [23, 51, 55, 26, 3, 53, 31, 54, 12, 25, 40, 18,

KeyboardInterrupt: 