In [1]:
import pandas as pd
import random
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch
import copy
import time

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
data = pd.read_json('../datasets/DailyNews_300.json')
print(data.shape)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

(300, 4)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.features = self.dataset.columns

    def __len__(self):
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        return (self.dataset.iloc[idx, 0], self.dataset.iloc[idx, 1], self.dataset.iloc[idx, 2])
    
    def map(self, preprocessing_fn, **kwargs):
        return CustomDataset(self.dataset.apply(lambda x: preprocessing_fn(x, **kwargs), axis = 1))
    
    def select_columns(self, columns):
        new_dataset = self.dataset[columns] 
        return CustomDataset(new_dataset)
    
    def get_sentences(self):
        self.dataset['sentences'] = self.dataset['text'].apply(lambda x: x.split('.'))
        return CustomDataset(self.dataset)
    
dataset = CustomDataset(data)
dataset = dataset.get_sentences()
# print(dataset.__getitem__(0))

In [18]:
def get_word_lengths(dataset, tokenizer):
    word_lengths = {}
    all_tokens = []

    for sample in dataset:
        summary = sample[0]
        preprocessed_result = tokenizer(summary, 
                                        add_special_tokens = False,
                                        truncation = True,
                                        max_length = 512,
                                        padding = False,
                                        return_attention_mask = False)
        tokens = preprocessed_result["input_ids"]
        decoded_tokens = tokenizer.convert_ids_to_tokens(tokens)
        for token in tokens:
            if token not in all_tokens:
                all_tokens.append(token)

        i = 0
        while i < len(tokens):
            if decoded_tokens[i].startswith('##'):
                combined_word = decoded_tokens[i - 1] + decoded_tokens[i][2:]
                word_lengths[tokens[i - 1]] = len(combined_word)
                word_lengths[tokens[i]] = len(combined_word)
            else:
                word_lengths[tokens[i]] = len(decoded_tokens[i])
            i += 1

    assert len(all_tokens) == len(word_lengths), "Association of tokens with word length : FAILED."

    return word_lengths

In [19]:
def preprocessing_fn(x, tokenizer):
    x["summary_ids"] = tokenizer(
        x["summary"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    x["text_ids"] = tokenizer(
        x["text"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    x["sentences_ids"] = tokenizer(
        x["sentences"],
        add_special_tokens = False,
        truncation = True,
        max_length = 512,
        padding = False,
        return_attention_mask = True,
    )["input_ids"]

    return x

splitted_dataset = dataset.select_columns(["summary", "text", "sentences"])
# print(splitted_dataset.__getitem__(0))

word_lengths = get_word_lengths(splitted_dataset, tokenizer)

# Tokenize the dataset
splitted_dataset = splitted_dataset.map(
    preprocessing_fn, tokenizer = tokenizer
)
print(splitted_dataset.__getitem__(0)[2])

# Remove useless columns
splitted_dataset = splitted_dataset.select_columns(["summary_ids", "text_ids", "sentences_ids"])
print(splitted_dataset.__getitem__(0)[2])

['Mario Mandzukic pounces to fire the ball past Jordan Pickford and put Croatia into the World Cup final', ' Photo: Reuters Independent', "ie \n \nFormer England defender Gary Neville suggested Gareth Southgate's squad had done more than could have been expected of them at this World Cup as they bowed out with a semi-final defeat against Croatia", ' \n \nhttps://www', 'independent', 'ie/sport/soccer/world-cup-2018/gary-neville-salutes-englands-overachievers-as-alan-shearer-and-rio-ferdinand-give-their-verdicts-37108667', 'html \n \nhttps://www', 'independent', 'ie/incoming/article37108634', 'ece/7571a/AUTOCROP/h342/52Man1', "jpg \n   Email     \nFormer England defender Gary Neville suggested Gareth Southgate's squad had done more than could have been expected of them at this World Cup as they bowed out with a semi-final defeat against Croatia", ' \n  \nA jaded England faded after had Kieran Trippier fired England ahead after just five minutes with a superb free-kick, with goals from Iv

In [24]:
def collate_fn(batch):
    # Get the tokenized sequences for each item in the batch
    text_ids_batch = [torch.tensor(item[1], dtype = torch.int) for item in batch]
    summary_ids_batch = [torch.tensor(item[0], dtype = torch.int) for item in batch]
    sentences_ids_batch = [
        [torch.tensor(sentence, dtype = torch.int) for sentence in item[2]]
        for item in batch
    ]

    # Pad sequences to the maximum length in the batch
    padded_text_ids = pad_sequence([torch.cat([item, torch.zeros(max(0, 512 - len(item)))]) for item in text_ids_batch], batch_first = True, padding_value = 0)
    padded_summary_ids = pad_sequence([torch.cat([item, torch.zeros(max(0, 512 - len(item)))]) for item in summary_ids_batch], batch_first = True, padding_value = 0)
    padded_sentences_ids = [
        pad_sequence(
            [torch.cat([sentence, torch.zeros(max(0, 512 - len(sentence)), dtype = torch.int)]) for sentence in item],
            batch_first = True,
            padding_value = 0
        )
        for item in sentences_ids_batch
    ]

    return {"text_ids": padded_text_ids, "summary_ids": padded_summary_ids, "sentences_ids": padded_sentences_ids}

batch_size = 32
dataloader = DataLoader(splitted_dataset, batch_size = batch_size, collate_fn = collate_fn)

epochs = 3
def training(summary, text, model, epochs = 10):
    model_copy = copy.deepcopy(model)
    model_copy.train()

    summary = summary.unsqueeze(0)
    text = text.unsqueeze(0)
    if summary.size(1) != text.size(1):
        raise RuntimeError("Sizes along the sequence length dimension must match.")
    
    for epochs in range(epochs):
        whole_input = torch.cat((summary, text), dim = 0).long()
        outputs = model_copy(whole_input, labels = whole_input)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return model_copy

In [68]:
def mask_sentence(sentence, i, M, word_lengths, l_min):
    tokenized_sentence = []
    masked_token_ids = []

    for j in range(len(sentence)):
        if (j - i) % M == 0 and sentence[j].item() in word_lengths and word_lengths[sentence[j].item()] >= l_min:
            tokenized_sentence.append(tokenizer.mask_token_id)
            masked_token_ids.append(j)
        else:
            tokenized_sentence.append(sentence[j].item())
            
    tokenized_sentence= torch.tensor(tokenized_sentence)
    
    return tokenized_sentence, masked_token_ids

In [89]:
def modified_BLANC_help(sentences, model, model_tuned, p_mask = 0.15, l_min = 4):
    S = [[0, 0], [0, 0]]
    M = int(1/p_mask)
    
    for sentence in sentences:
        for i in range(M):
            masked_sentence, masked_tokens_ids = mask_sentence(sentence, i, M, word_lengths, l_min)
            masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()
            sentence = sentence.long()

            out_base = model(masked_sentence, labels = sentence).logits
            out_help = model_tuned(masked_sentence, labels = sentence).logits

            for j in masked_tokens_ids:
                predicted_token_model = torch.argmax(out_base[0][j])
                predicted_token_model_tuned = torch.argmax(out_help[0][j])
                
                k = int(predicted_token_model == sentence[j])
                m = int(predicted_token_model_tuned == sentence[j])
                S[k][m] += 1
        break
    try:
      B = (S[0][1] - S[1][0]) / (S[0][0] + S[1][1] + S[0][1] + S[1][0])
    except ZeroDivisionError:
      B = 0.0
    
    return B

In [91]:
def blanc_tune(summary, text, model, p_mask = 0.15, l_min = 4, N = 10, epochs = 10):
    N_summary = len(summary)
    N_mask = int(N_summary*p_mask)
    set_tune = pd.DataFrame(columns = ['summary', 'text'])

    for j in range(0, N):
        pos = [i for i, token in enumerate(summary.tolist()) if token in word_lengths and word_lengths[token] >= l_min]
        random.shuffle(pos)
        while len(pos) != 0:
            masked_summary = summary.tolist().copy()
            for pos_to_mask in pos[:N_mask]:
                masked_summary[pos_to_mask] = '[MASK]'
                set_tune.loc[set_tune.shape[0]] = [masked_summary, text]
            pos = pos[N_mask:]

    model_tuned = training(summary, text, model, epochs)
    print('\n')      
    return model_tuned

def blanc_tune_batch(batch, model, p_mask = 0.15, l_min = 4, N = 10, epochs = 10):
    batch_tuned_models = []
    batch_accuracies = []
    
    i = 0
    for summary, text, sentences in zip(batch['summary_ids'], batch['text_ids'], batch['sentences_ids']):
        print(f"Summary {i} of batch")
        i += 1
        start_time = time.time()
        model_tuned = blanc_tune(summary, text, model, p_mask, l_min, N, epochs)
        batch_tuned_models.append(model_tuned)
        accuracy = modified_BLANC_help(sentences, model, model_tuned)
        print(accuracy)
        end_time = time.time()
        batch_accuracies.append(accuracy)
        elapsed_time = end_time - start_time
        print(f"Elapsed Time: {elapsed_time} seconds")

    return batch_tuned_models, batch_accuracies

for batch in dataloader:
    models_tuned, text_accuracies = blanc_tune_batch(batch, model, epochs = epochs)

Summary 0 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 18.661803245544434 seconds
Summary 1 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.12
Elapsed Time: 24.518627882003784 seconds
Summary 2 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.16666666666666666
Elapsed Time: 18.158257722854614 seconds
Summary 3 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.2
Elapsed Time: 16.86569356918335 seconds
Summary 4 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.2631578947368421
Elapsed Time: 18.028971910476685 seconds
Summary 5 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.043478260869565216
Elapsed Time: 17.49683976173401 seconds
Summary 6 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 18.282581090927124 seconds
Summary 7 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.045454545454545456
Elapsed Time: 19.03317379951477 seconds
Summary 8 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.0625
Elapsed Time: 17.374573945999146 seconds
Summary 9 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.05
Elapsed Time: 16.60593342781067 seconds
Summary 10 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.1
Elapsed Time: 17.074475526809692 seconds
Summary 11 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.69163155555725 seconds
Summary 12 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.09523809523809523
Elapsed Time: 18.318114757537842 seconds
Summary 13 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.688296794891357 seconds
Summary 14 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.08695652173913043
Elapsed Time: 17.29972767829895 seconds
Summary 15 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.15789473684210525
Elapsed Time: 17.376991033554077 seconds
Summary 16 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.198835849761963 seconds
Summary 17 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 16.440516233444214 seconds
Summary 18 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.62704873085022 seconds
Summary 19 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 16.76411724090576 seconds
Summary 20 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.121089458465576 seconds
Summary 21 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.23695945739746 seconds
Summary 22 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.043478260869565216
Elapsed Time: 17.109238147735596 seconds
Summary 23 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.1
Elapsed Time: 23.85978865623474 seconds
Summary 24 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.07142857142857142
Elapsed Time: 22.2849862575531 seconds
Summary 25 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.125
Elapsed Time: 16.91410803794861 seconds
Summary 26 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.11764705882352941
Elapsed Time: 17.520513772964478 seconds
Summary 27 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.28
Elapsed Time: 18.59769606590271 seconds
Summary 28 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.045454545454545456
Elapsed Time: 16.348493576049805 seconds
Summary 29 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.18181818181818182
Elapsed Time: 16.924741983413696 seconds
Summary 30 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 16.893242597579956 seconds
Summary 31 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 17.26734495162964 seconds
Summary 0 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.15789473684210525
Elapsed Time: 17.016314268112183 seconds
Summary 1 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.12
Elapsed Time: 18.084744691848755 seconds
Summary 2 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


-0.1111111111111111
Elapsed Time: 18.33009648323059 seconds
Summary 3 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 20.499070644378662 seconds
Summary 4 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


0.0
Elapsed Time: 34.80295991897583 seconds
Summary 5 of batch
0
1
2
3
4
5
6
7
8
9




  masked_sentence = torch.tensor(masked_sentence).view(1, -1).long()


: 