In [1]:
import transformers
import datasets
import torch
import logging

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Preprocessing
sep_token = '<sep>'
special_token = '<ANSWER>' # between context and answer
dataset_name = "squad"
models_dir = "saved_models/bart_base_answer-aware_squad"
checkpoint = "facebook/bart-base"
max_input_length = 768
max_target_length = 128
cuda_device = torch.device("cuda:0")

## Training
learning_rate = 1e-4
num_epochs = 3

In [3]:
dataset = datasets.load_dataset(dataset_name)

Found cached dataset squad (C:/Users/manuv/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
100%|██████████| 2/2 [00:00<00:00, 199.96it/s]


In [4]:
print(len(dataset["train"]))

87599


In [5]:
dataset["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [6]:
model = transformers.BartForConditionalGeneration.from_pretrained(checkpoint).to(cuda_device)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

In [7]:
tokenizer.sep_token = sep_token
tokenizer.add_tokens([sep_token])
tokenizer.add_tokens([special_token])
tokenizer.special_tokens_map.update({special_token: '[unused100]'})
model.resize_token_embeddings(len(tokenizer))

Embedding(50267, 768)

In [8]:
# Tokenize examples
def convert_to_features(example_batch):

    input_encodings = tokenizer.batch_encode_plus(example_batch['input'], 
                                                  max_length=max_input_length, 
                                                  add_special_tokens=True,
                                                  truncation=True, 
                                                  pad_to_max_length=True)
    
    target_encodings = tokenizer.batch_encode_plus(example_batch['question'], 
                                                   max_length=max_target_length, 
                                                   add_special_tokens=True,
                                                   truncation=True, pad_to_max_length=True)
                                                   
    encodings = {
        'input_ids': input_encodings['input_ids'], 
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': target_encodings['input_ids'],
        'decoder_attention_mask': target_encodings['attention_mask']
    }

    return encodings

def add_eos_examples(example):
    example['input'] = example['context'] + " " + special_token + example["answers"]["text"][0] + " " + sep_token
    example['question'] = example['question'] + " " + sep_token
    return example


def add_special_tokens(example):
  example['question'] = example['question'].replace("{sep_token}", sep_token)
  return example

In [9]:
tokenized_dataset  = dataset.map(add_eos_examples)
tokenized_dataset = tokenized_dataset.map(add_special_tokens)
tokenized_dataset  = tokenized_dataset.map(convert_to_features,  batched=True)

Loading cached processed dataset at C:\Users\manuv\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-ae087ab63e8c4789.arrow
Loading cached processed dataset at C:\Users\manuv\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-74b418349551b109.arrow
Loading cached processed dataset at C:\Users\manuv\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-c6a7fde7ed727316.arrow
Loading cached processed dataset at C:\Users\manuv\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-3b93c6b9007bad4e.arrow
Loading cached processed dataset at C:\Users\manuv\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-0cbf9ea6ceeaac65.arrow
                    

In [10]:
tokenized_dataset["train"][0]["question"]

'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? <sep>'

In [11]:
tokenized_dataset["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? <sep>',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]},
 'input': 'Architecturally, the school has a

In [12]:

tokenized_dataset = tokenized_dataset.remove_columns(
    ["input","context","question","answers"]
)

train_dataset = tokenized_dataset["train"]
valid_dataset = tokenized_dataset["validation"]

columns = ['input_ids', 'decoder_input_ids', 'attention_mask', 'decoder_attention_mask']
train_dataset.set_format(type='torch', columns=columns)
valid_dataset.set_format(type='torch', columns=columns)

In [13]:
torch.save(train_dataset, 'train_data.pt')
torch.save(valid_dataset, 'valid_data.pt')

In [14]:
from typing import Dict, List

class T2TDataCollator():
    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Returns:
        A dictionary of tensors
        """

        input_ids = torch.stack([example['input_ids'] for example in batch])
        lm_labels = torch.stack([example['decoder_input_ids'] for example in batch])
        lm_labels[lm_labels[:, :] == 0] = -100 
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        decoder_attention_mask = torch.stack([example['decoder_attention_mask'] for example in batch])

        return {
            'input_ids': input_ids, 
            'attention_mask': attention_mask,
            'labels': lm_labels, 
            'decoder_attention_mask': decoder_attention_mask
        }

In [15]:
training_args = transformers.TrainingArguments(output_dir=models_dir, 
                                  per_device_train_batch_size=4, 
                                  per_device_eval_batch_size=4,
                                  gradient_accumulation_steps=16,
                                  learning_rate=learning_rate, 
                                  num_train_epochs=num_epochs,
                                  logging_steps=100,
                                  run_name="bart_answer-aware_qg_squad",
                                  evaluation_strategy="steps",
                                  save_steps=500)

In [16]:
logger = logging.getLogger(__name__)

# Initialize our Trainer
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=T2TDataCollator()
)


In [17]:
if model.device.type == 'cuda':
    print('Model is on GPU')
else:
    print('Model is on CPU')

Model is on GPU


In [18]:
trainer.train() 

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmanu-vleurick[0m ([33mhogent-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


  2%|▏         | 100/4104 [07:54<5:15:05,  4.72s/it]

{'loss': 1.8227, 'learning_rate': 9.756335282651073e-05, 'epoch': 0.07}


                                                    
  2%|▏         | 100/4104 [12:09<5:15:05,  4.72s/it]

{'eval_loss': 0.19733433425426483, 'eval_runtime': 254.7636, 'eval_samples_per_second': 41.489, 'eval_steps_per_second': 10.374, 'epoch': 0.07}


  5%|▍         | 200/4104 [20:02<5:07:56,  4.73s/it] 

{'loss': 0.2148, 'learning_rate': 9.512670565302145e-05, 'epoch': 0.15}


                                                    
  5%|▍         | 200/4104 [24:18<5:07:56,  4.73s/it]

{'eval_loss': 0.18540476262569427, 'eval_runtime': 255.3556, 'eval_samples_per_second': 41.393, 'eval_steps_per_second': 10.35, 'epoch': 0.15}


  7%|▋         | 300/4104 [32:11<5:00:10,  4.73s/it] 

{'loss': 0.2062, 'learning_rate': 9.269005847953217e-05, 'epoch': 0.22}


                                                    
  7%|▋         | 300/4104 [36:27<5:00:10,  4.73s/it]

{'eval_loss': 0.18565750122070312, 'eval_runtime': 255.4533, 'eval_samples_per_second': 41.377, 'eval_steps_per_second': 10.346, 'epoch': 0.22}


 10%|▉         | 400/4104 [44:20<4:52:27,  4.74s/it] 

{'loss': 0.2017, 'learning_rate': 9.025341130604289e-05, 'epoch': 0.29}


                                                    
 10%|▉         | 400/4104 [48:36<4:52:27,  4.74s/it]

{'eval_loss': 0.18062256276607513, 'eval_runtime': 255.5192, 'eval_samples_per_second': 41.367, 'eval_steps_per_second': 10.344, 'epoch': 0.29}


 12%|█▏        | 500/4104 [56:30<4:44:36,  4.74s/it] 

{'loss': 0.195, 'learning_rate': 8.78167641325536e-05, 'epoch': 0.37}


                                                    
 12%|█▏        | 500/4104 [1:00:45<4:44:36,  4.74s/it]

{'eval_loss': 0.179555743932724, 'eval_runtime': 255.5111, 'eval_samples_per_second': 41.368, 'eval_steps_per_second': 10.344, 'epoch': 0.37}


 15%|█▍        | 600/4104 [1:09:05<4:36:35,  4.74s/it] 

{'loss': 0.1935, 'learning_rate': 8.538011695906433e-05, 'epoch': 0.44}


                                                      
 15%|█▍        | 600/4104 [1:13:20<4:36:35,  4.74s/it]

{'eval_loss': 0.17684946954250336, 'eval_runtime': 255.5541, 'eval_samples_per_second': 41.361, 'eval_steps_per_second': 10.342, 'epoch': 0.44}


 17%|█▋        | 700/4104 [1:21:14<4:28:34,  4.73s/it] 

{'loss': 0.1886, 'learning_rate': 8.294346978557506e-05, 'epoch': 0.51}


                                                      
 17%|█▋        | 700/4104 [1:25:30<4:28:34,  4.73s/it]

{'eval_loss': 0.1766921877861023, 'eval_runtime': 255.6506, 'eval_samples_per_second': 41.345, 'eval_steps_per_second': 10.338, 'epoch': 0.51}


 19%|█▉        | 800/4104 [1:33:24<4:20:49,  4.74s/it] 

{'loss': 0.1844, 'learning_rate': 8.050682261208578e-05, 'epoch': 0.58}


                                                      
 19%|█▉        | 800/4104 [1:37:39<4:20:49,  4.74s/it]

{'eval_loss': 0.17447958886623383, 'eval_runtime': 255.3582, 'eval_samples_per_second': 41.393, 'eval_steps_per_second': 10.35, 'epoch': 0.58}


 22%|██▏       | 900/4104 [1:45:33<4:12:50,  4.73s/it] 

{'loss': 0.1875, 'learning_rate': 7.807017543859649e-05, 'epoch': 0.66}


                                                      
 22%|██▏       | 900/4104 [1:49:48<4:12:50,  4.73s/it]

{'eval_loss': 0.17443932592868805, 'eval_runtime': 255.5861, 'eval_samples_per_second': 41.356, 'eval_steps_per_second': 10.341, 'epoch': 0.66}


 24%|██▍       | 1000/4104 [1:57:42<4:04:59,  4.74s/it]

{'loss': 0.1835, 'learning_rate': 7.563352826510721e-05, 'epoch': 0.73}


                                                       
 24%|██▍       | 1000/4104 [2:01:58<4:04:59,  4.74s/it]

{'eval_loss': 0.17313124239444733, 'eval_runtime': 255.5466, 'eval_samples_per_second': 41.362, 'eval_steps_per_second': 10.343, 'epoch': 0.73}


 27%|██▋       | 1100/4104 [2:09:54<3:55:21,  4.70s/it] 

{'loss': 0.1838, 'learning_rate': 7.319688109161794e-05, 'epoch': 0.8}


                                                       
 27%|██▋       | 1100/4104 [2:14:08<3:55:21,  4.70s/it]

{'eval_loss': 0.17158250510692596, 'eval_runtime': 253.5777, 'eval_samples_per_second': 41.683, 'eval_steps_per_second': 10.423, 'epoch': 0.8}


 29%|██▉       | 1200/4104 [2:21:59<3:47:32,  4.70s/it] 

{'loss': 0.1819, 'learning_rate': 7.076023391812866e-05, 'epoch': 0.88}


                                                       
 29%|██▉       | 1200/4104 [2:26:12<3:47:32,  4.70s/it]

{'eval_loss': 0.1711895614862442, 'eval_runtime': 253.6594, 'eval_samples_per_second': 41.67, 'eval_steps_per_second': 10.419, 'epoch': 0.88}


 32%|███▏      | 1300/4104 [2:34:02<3:39:38,  4.70s/it] 

{'loss': 0.1805, 'learning_rate': 6.832358674463938e-05, 'epoch': 0.95}


                                                       
 32%|███▏      | 1300/4104 [2:38:16<3:39:38,  4.70s/it]

{'eval_loss': 0.17114779353141785, 'eval_runtime': 253.5821, 'eval_samples_per_second': 41.683, 'eval_steps_per_second': 10.423, 'epoch': 0.95}


 34%|███▍      | 1400/4104 [2:46:06<3:31:48,  4.70s/it] 

{'loss': 0.1731, 'learning_rate': 6.58869395711501e-05, 'epoch': 1.02}


                                                       
 34%|███▍      | 1400/4104 [2:50:20<3:31:48,  4.70s/it]

{'eval_loss': 0.17392054200172424, 'eval_runtime': 253.6522, 'eval_samples_per_second': 41.671, 'eval_steps_per_second': 10.42, 'epoch': 1.02}


 37%|███▋      | 1500/4104 [2:58:10<3:23:58,  4.70s/it] 

{'loss': 0.1545, 'learning_rate': 6.345029239766082e-05, 'epoch': 1.1}


                                                       
 37%|███▋      | 1500/4104 [3:02:24<3:23:58,  4.70s/it]

{'eval_loss': 0.17261776328086853, 'eval_runtime': 253.649, 'eval_samples_per_second': 41.672, 'eval_steps_per_second': 10.42, 'epoch': 1.1}


 39%|███▉      | 1600/4104 [3:10:18<3:16:07,  4.70s/it] 

{'loss': 0.1545, 'learning_rate': 6.101364522417154e-05, 'epoch': 1.17}


                                                       
 39%|███▉      | 1600/4104 [3:14:32<3:16:07,  4.70s/it]

{'eval_loss': 0.17098908126354218, 'eval_runtime': 253.5656, 'eval_samples_per_second': 41.685, 'eval_steps_per_second': 10.423, 'epoch': 1.17}


 41%|████▏     | 1700/4104 [3:22:22<3:08:18,  4.70s/it] 

{'loss': 0.1521, 'learning_rate': 5.8576998050682263e-05, 'epoch': 1.24}


                                                       
 41%|████▏     | 1700/4104 [3:26:36<3:08:18,  4.70s/it]

{'eval_loss': 0.1710432469844818, 'eval_runtime': 253.5557, 'eval_samples_per_second': 41.687, 'eval_steps_per_second': 10.424, 'epoch': 1.24}


 44%|████▍     | 1800/4104 [3:34:26<3:00:28,  4.70s/it] 

{'loss': 0.1521, 'learning_rate': 5.6140350877192984e-05, 'epoch': 1.32}


                                                       
 44%|████▍     | 1800/4104 [3:38:39<3:00:28,  4.70s/it]

{'eval_loss': 0.17097194492816925, 'eval_runtime': 253.514, 'eval_samples_per_second': 41.694, 'eval_steps_per_second': 10.425, 'epoch': 1.32}


 46%|████▋     | 1900/4104 [3:46:29<2:52:38,  4.70s/it] 

{'loss': 0.1538, 'learning_rate': 5.370370370370371e-05, 'epoch': 1.39}


                                                       
 46%|████▋     | 1900/4104 [3:50:43<2:52:38,  4.70s/it]

{'eval_loss': 0.1710260957479477, 'eval_runtime': 253.4376, 'eval_samples_per_second': 41.707, 'eval_steps_per_second': 10.429, 'epoch': 1.39}


 49%|████▊     | 2000/4104 [3:58:33<2:44:46,  4.70s/it] 

{'loss': 0.1523, 'learning_rate': 5.126705653021443e-05, 'epoch': 1.46}


                                                       
 49%|████▊     | 2000/4104 [4:02:46<2:44:46,  4.70s/it]

{'eval_loss': 0.1703895926475525, 'eval_runtime': 253.5507, 'eval_samples_per_second': 41.688, 'eval_steps_per_second': 10.424, 'epoch': 1.46}


 51%|█████     | 2100/4104 [4:11:05<2:36:57,  4.70s/it] 

{'loss': 0.1534, 'learning_rate': 4.883040935672515e-05, 'epoch': 1.53}


                                                       
 51%|█████     | 2100/4104 [4:15:18<2:36:57,  4.70s/it]

{'eval_loss': 0.17019546031951904, 'eval_runtime': 253.5747, 'eval_samples_per_second': 41.684, 'eval_steps_per_second': 10.423, 'epoch': 1.53}


 54%|█████▎    | 2200/4104 [4:23:09<2:29:09,  4.70s/it] 

{'loss': 0.1524, 'learning_rate': 4.6393762183235865e-05, 'epoch': 1.61}


                                                       
 54%|█████▎    | 2200/4104 [4:27:22<2:29:09,  4.70s/it]

{'eval_loss': 0.16884207725524902, 'eval_runtime': 253.7417, 'eval_samples_per_second': 41.657, 'eval_steps_per_second': 10.416, 'epoch': 1.61}


 56%|█████▌    | 2300/4104 [4:35:16<2:22:23,  4.74s/it] 

{'loss': 0.1507, 'learning_rate': 4.395711500974659e-05, 'epoch': 1.68}


                                                       
 56%|█████▌    | 2300/4104 [4:39:32<2:22:23,  4.74s/it]

{'eval_loss': 0.16891217231750488, 'eval_runtime': 255.5256, 'eval_samples_per_second': 41.366, 'eval_steps_per_second': 10.343, 'epoch': 1.68}


 58%|█████▊    | 2400/4104 [4:47:25<2:14:31,  4.74s/it] 

{'loss': 0.1493, 'learning_rate': 4.152046783625731e-05, 'epoch': 1.75}


                                                       
 58%|█████▊    | 2400/4104 [4:51:41<2:14:31,  4.74s/it]

{'eval_loss': 0.16923968493938446, 'eval_runtime': 255.5589, 'eval_samples_per_second': 41.36, 'eval_steps_per_second': 10.342, 'epoch': 1.75}


 61%|██████    | 2500/4104 [4:59:35<2:06:38,  4.74s/it] 

{'loss': 0.1528, 'learning_rate': 3.908382066276803e-05, 'epoch': 1.83}


                                                       
 61%|██████    | 2500/4104 [5:03:50<2:06:38,  4.74s/it]

{'eval_loss': 0.16780829429626465, 'eval_runtime': 255.5524, 'eval_samples_per_second': 41.361, 'eval_steps_per_second': 10.342, 'epoch': 1.83}


 63%|██████▎   | 2600/4104 [5:12:16<1:58:46,  4.74s/it] 

{'loss': 0.1509, 'learning_rate': 3.664717348927875e-05, 'epoch': 1.9}


                                                       
 63%|██████▎   | 2600/4104 [5:16:32<1:58:46,  4.74s/it]

{'eval_loss': 0.16810578107833862, 'eval_runtime': 255.6167, 'eval_samples_per_second': 41.351, 'eval_steps_per_second': 10.34, 'epoch': 1.9}


 66%|██████▌   | 2700/4104 [5:24:25<1:50:54,  4.74s/it] 

{'loss': 0.1472, 'learning_rate': 3.421052631578947e-05, 'epoch': 1.97}


                                                       
 66%|██████▌   | 2700/4104 [5:28:41<1:50:54,  4.74s/it]

{'eval_loss': 0.16708733141422272, 'eval_runtime': 255.6839, 'eval_samples_per_second': 41.34, 'eval_steps_per_second': 10.337, 'epoch': 1.97}


 68%|██████▊   | 2800/4104 [5:36:36<1:43:00,  4.74s/it] 

{'loss': 0.1383, 'learning_rate': 3.1773879142300193e-05, 'epoch': 2.05}


                                                       
 68%|██████▊   | 2800/4104 [5:40:51<1:43:00,  4.74s/it]

{'eval_loss': 0.16973377764225006, 'eval_runtime': 255.5691, 'eval_samples_per_second': 41.359, 'eval_steps_per_second': 10.342, 'epoch': 2.05}


 71%|███████   | 2900/4104 [5:48:45<1:35:00,  4.74s/it] 

{'loss': 0.1288, 'learning_rate': 2.9337231968810917e-05, 'epoch': 2.12}


                                                       
 71%|███████   | 2900/4104 [5:53:01<1:35:00,  4.74s/it]

{'eval_loss': 0.17079736292362213, 'eval_runtime': 255.5591, 'eval_samples_per_second': 41.36, 'eval_steps_per_second': 10.342, 'epoch': 2.12}


 73%|███████▎  | 3000/4104 [6:00:54<1:27:08,  4.74s/it] 

{'loss': 0.1313, 'learning_rate': 2.6900584795321637e-05, 'epoch': 2.19}


                                                       
 73%|███████▎  | 3000/4104 [6:05:10<1:27:08,  4.74s/it]

{'eval_loss': 0.16982384026050568, 'eval_runtime': 255.5619, 'eval_samples_per_second': 41.36, 'eval_steps_per_second': 10.342, 'epoch': 2.19}


 76%|███████▌  | 3100/4104 [6:13:35<1:19:13,  4.73s/it] 

{'loss': 0.1303, 'learning_rate': 2.4463937621832358e-05, 'epoch': 2.26}


                                                       
 76%|███████▌  | 3100/4104 [6:17:50<1:19:13,  4.73s/it]

{'eval_loss': 0.16963498294353485, 'eval_runtime': 255.4539, 'eval_samples_per_second': 41.377, 'eval_steps_per_second': 10.346, 'epoch': 2.26}


 78%|███████▊  | 3200/4104 [6:25:44<1:11:24,  4.74s/it] 

{'loss': 0.1306, 'learning_rate': 2.2027290448343078e-05, 'epoch': 2.34}


                                                       
 78%|███████▊  | 3200/4104 [6:30:00<1:11:24,  4.74s/it]

{'eval_loss': 0.1694660484790802, 'eval_runtime': 255.4814, 'eval_samples_per_second': 41.373, 'eval_steps_per_second': 10.345, 'epoch': 2.34}


 80%|████████  | 3300/4104 [6:37:53<1:03:27,  4.74s/it] 

{'loss': 0.1305, 'learning_rate': 1.9590643274853802e-05, 'epoch': 2.41}


                                                       
 80%|████████  | 3300/4104 [6:42:09<1:03:27,  4.74s/it]

{'eval_loss': 0.16849184036254883, 'eval_runtime': 255.4703, 'eval_samples_per_second': 41.375, 'eval_steps_per_second': 10.346, 'epoch': 2.41}


 83%|████████▎ | 3400/4104 [6:50:02<55:34,  4.74s/it]   

{'loss': 0.13, 'learning_rate': 1.7153996101364522e-05, 'epoch': 2.48}


                                                     
 83%|████████▎ | 3400/4104 [6:54:18<55:34,  4.74s/it]

{'eval_loss': 0.16920869052410126, 'eval_runtime': 255.4571, 'eval_samples_per_second': 41.377, 'eval_steps_per_second': 10.346, 'epoch': 2.48}


 85%|████████▌ | 3500/4104 [7:02:11<47:40,  4.74s/it]   

{'loss': 0.131, 'learning_rate': 1.4717348927875244e-05, 'epoch': 2.56}


                                                     
 85%|████████▌ | 3500/4104 [7:06:27<47:40,  4.74s/it]

{'eval_loss': 0.16852499544620514, 'eval_runtime': 255.4956, 'eval_samples_per_second': 41.371, 'eval_steps_per_second': 10.345, 'epoch': 2.56}


 88%|████████▊ | 3600/4104 [7:14:50<40:36,  4.83s/it]   

{'loss': 0.129, 'learning_rate': 1.2280701754385964e-05, 'epoch': 2.63}


                                                     
 88%|████████▊ | 3600/4104 [7:19:06<40:36,  4.83s/it]

{'eval_loss': 0.16798582673072815, 'eval_runtime': 255.4221, 'eval_samples_per_second': 41.382, 'eval_steps_per_second': 10.348, 'epoch': 2.63}


 90%|█████████ | 3700/4104 [7:26:59<31:50,  4.73s/it]   

{'loss': 0.13, 'learning_rate': 9.844054580896686e-06, 'epoch': 2.7}


                                                     
 90%|█████████ | 3700/4104 [7:31:14<31:50,  4.73s/it]

{'eval_loss': 0.1686238944530487, 'eval_runtime': 255.286, 'eval_samples_per_second': 41.405, 'eval_steps_per_second': 10.353, 'epoch': 2.7}


 93%|█████████▎| 3800/4104 [7:39:10<24:29,  4.83s/it]  

{'loss': 0.1298, 'learning_rate': 7.4074074074074075e-06, 'epoch': 2.78}


                                                     
 93%|█████████▎| 3800/4104 [7:43:25<24:29,  4.83s/it]

{'eval_loss': 0.16804394125938416, 'eval_runtime': 255.4211, 'eval_samples_per_second': 41.383, 'eval_steps_per_second': 10.348, 'epoch': 2.78}


 95%|█████████▌| 3900/4104 [7:51:18<16:05,  4.73s/it]  

{'loss': 0.1291, 'learning_rate': 4.970760233918129e-06, 'epoch': 2.85}


                                                     
 95%|█████████▌| 3900/4104 [7:55:34<16:05,  4.73s/it]

{'eval_loss': 0.16799214482307434, 'eval_runtime': 255.292, 'eval_samples_per_second': 41.404, 'eval_steps_per_second': 10.353, 'epoch': 2.85}


 97%|█████████▋| 4000/4104 [8:03:27<08:11,  4.73s/it]  

{'loss': 0.1289, 'learning_rate': 2.5341130604288498e-06, 'epoch': 2.92}


                                                     
 97%|█████████▋| 4000/4104 [8:07:42<08:11,  4.73s/it]

{'eval_loss': 0.16766060888767242, 'eval_runtime': 255.4496, 'eval_samples_per_second': 41.378, 'eval_steps_per_second': 10.346, 'epoch': 2.92}


100%|█████████▉| 4100/4104 [8:16:13<00:18,  4.73s/it]  

{'loss': 0.1293, 'learning_rate': 9.746588693957116e-08, 'epoch': 3.0}


                                                     
100%|█████████▉| 4100/4104 [8:20:29<00:18,  4.73s/it]

{'eval_loss': 0.16763707995414734, 'eval_runtime': 255.371, 'eval_samples_per_second': 41.391, 'eval_steps_per_second': 10.35, 'epoch': 3.0}


100%|██████████| 4104/4104 [8:20:48<00:00,  7.32s/it]

{'train_runtime': 30051.5237, 'train_samples_per_second': 8.745, 'train_steps_per_second': 0.137, 'train_loss': 0.19749631400112985, 'epoch': 3.0}





TrainOutput(global_step=4104, training_loss=0.19749631400112985, metrics={'train_runtime': 30051.5237, 'train_samples_per_second': 8.745, 'train_steps_per_second': 0.137, 'train_loss': 0.19749631400112985, 'epoch': 3.0})

In [19]:
# save the model
trainer.save_model(models_dir)

In [20]:

from transformers import BartForConditionalGeneration, AutoTokenizer
# load the saved model
loaded_model = BartForConditionalGeneration.from_pretrained(models_dir)

In [21]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [25]:
loaded_model.to("cuda:0")

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50267, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50267, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [26]:

def run_model(input_string,target_answer, **generator_args):
    generator_args = {
    "max_length": 768,
    "num_beams": 4,# bij grotere num_beams is trager maar complexere vragen(niet per se betere vragen)
    "length_penalty": 1.5,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    }
    
    input_string = input_string + " " + "<ANSWER>" + target_answer + " </s>"
    input_ids = tokenizer.encode(input_string, return_tensors="pt")
    res = loaded_model.generate(input_ids.to(cuda_device), **generator_args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    output = [item.split("<sep>") for item in output]
    return output


In [27]:
context = """
Cheese is an ancient food whose origins predate recorded history. There is no conclusive evidence indicating where cheesemaking originated, whether in Europe, Central Asia or the Middle East. Earliest proposed dates for the origin of cheesemaking range from around 8000 BCE, when sheep were first domesticated. Since animal skins and inflated internal organs have, since ancient times, provided storage vessels for a range of foodstuffs, it is probable that the process of cheese making was discovered accidentally by storing milk in a container made from the stomach of an animal, resulting in the milk being turned to curd and whey by the rennet from the stomach.[7] There is a legend—with variations—about the discovery of cheese by an Arab trader who used this method of storing milk.[8]

The earliest evidence of cheesemaking in the archaeological record dates back to 5500 BCE and is found in what is now Kuyavia, Poland, where strainers coated with milk-fat molecules have been found.[9]

Cheesemaking may have begun independently of this by the pressing and salting of curdled milk to preserve it. Observation that the effect of making cheese in an animal stomach gave more solid and better-textured curds may have led to the deliberate addition of rennet. Early archeological evidence of Egyptian cheese has been found in Egyptian tomb murals, dating to about 2000 BCE.[10] A 2018 scientific paper stated that the world's oldest cheese, dating to approximately 1200 BCE (3200 years before present), was found in ancient Egyptian tombs.[11][12]

The earliest cheeses were likely quite sour and salty, similar in texture to rustic cottage cheese or feta, a crumbly, flavorful Greek cheese. Cheese produced in Europe, where climates are cooler than the Middle East, required less salt for preservation. With less salt and acidity, the cheese became a suitable environment for useful microbes and molds, giving aged cheeses their respective flavors. The earliest ever discovered preserved cheese was found in the Taklamakan Desert in Xinjiang, China, dating back as early as 1615 BCE (3600 years before present).
"""

context_gf = """
#### Types of Bias

**Selection bias** is the tendency to skew your choice of data sources to 
those that are easily available, convenient, and/or cost-effective. As a 
result of this a bias is introduced by the selection of individuals, groups 
or data for analysis in such a way that proper randomization is not achieved, 
thereby ensuring that the sample obtained is not representative of the 
population intended to be analyzed. 
[Learn more about selection bias](https://en.wikipedia.org/wiki/Selection_bias).

**Self-selection bias** is a form of selection bias where you get the data 
from sources that “volunteered” to provide it. Most poll data has this type  
of bias. [Learn more about self-selection bias](https://en.wikipedia.org/wiki/Self-selection_bias)

**Omitted-variable bias** happens when your featurized data doesn't have a 
feature necessary for accurate prediction. For example, let's assume that you 
are working on a churn prediction model and you want to predict whether a 
customer cancels their subscription within six months. You train a model, and 
it's accurate enough; however, several weeks after deployment you see many 
unexpected false negatives. You investigate the decreased model performance 
and discover a new competitor now offers a very similar service for a lower 
price. This feature wasn't initially available to your model, therefore 
important information for accurate prediction was missing. 
[Learn more about omitted variable Bias](https://en.wikipedia.org/wiki/Omitted-variable_bias).
"""

run_model(context,"Cheese")

[['What is an ancient food whose origins predate recorded history? ']]

In [None]:
# final metrics
trainer.evaluate()

100%|██████████| 2643/2643 [02:59<00:00, 14.71it/s]


{'eval_loss': 0.19722864031791687,
 'eval_runtime': 179.9317,
 'eval_samples_per_second': 58.744,
 'eval_steps_per_second': 14.689,
 'epoch': 5.0}

## Evaluation

In [1]:
import numpy as np
import nltk
from rouge_score import rouge_scorer
from nltk.translate.meteor_score import meteor_score as calculate_meteor
from nltk.translate.bleu_score import SmoothingFunction
import datasets
from transformers import BartForConditionalGeneration, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_name = "squad"
models_dir = "saved_models/bart_base_answer-aware_squad"
checkpoint = "facebook/bart-base"
val_perc = 60
val_dataset = datasets.load_dataset(dataset_name, split=f'validation[:{val_perc}%]')
test_dataset = datasets.load_dataset(dataset_name,split=f"validation[{val_perc}%:]")

loaded_model = BartForConditionalGeneration.from_pretrained(models_dir).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Found cached dataset squad (C:/Users/manuv/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
Found cached dataset squad (C:/Users/manuv/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


In [3]:

def run_model(input_string,target_answer, **generator_args):
    generator_args = {
    "max_length": 768,
    "num_beams": 4,# bij grotere num_beams is trager maar complexere vragen(niet per se betere vragen)
    "length_penalty": 1.5,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    }
    
    input_string = input_string + " " + "<ANSWER>" + target_answer + " </s>"
    input_ids = tokenizer.encode(input_string, return_tensors="pt")
    res = loaded_model.generate(input_ids.to("cuda:0"), **generator_args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    output = [item.split("<sep>") for item in output]
    return output


### Automatic Metrics

In [4]:
# 1.8s per row 
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
smoother = SmoothingFunction().method1

bleu_scores = np.array([])
rouge_scores = np.array([])
meteor_scores = np.array([])
count = 0
for row in test_dataset:
    context,target_question,answer = row["context"],row["question"],','.join(set(row["answers"]["text"]))

    gen_question = run_model(context,answer)[0][0].strip()
    print(f'gen_question: {gen_question} <-> target_question: {target_question} = {answer} | {count}')

    # Tokenization
    gen_tokens = nltk.word_tokenize(gen_question.lower())
    target_tokens = nltk.word_tokenize(target_question.lower())

    # Calculate scores
    bleu_score = nltk.translate.bleu_score.sentence_bleu([target_tokens], gen_tokens,smoothing_function=smoother)
    bleu_scores = np.append(bleu_scores,bleu_score)

    # Calculate ROUGE score
    rouge_score = scorer.score(target_question, gen_question)# Doesn't need tokenization
    rouge_scores = np.append(rouge_scores,rouge_score)

    # Calculate METEOR score
    meteor_score = calculate_meteor([target_tokens],gen_tokens)
    
    meteor_scores = np.append(meteor_scores,meteor_score)

    count+=1
 

rouge_scores = [score["rougeL"].fmeasure for score in rouge_scores]# Longest common subsequence and the f-score

avg_bleu_score = np.mean(bleu_scores)
avg_rouge_score = np.mean(rouge_scores)
avg_meteor_score = np.mean(meteor_scores)

print(f'Average BLEU4: {avg_bleu_score}')
print(f'Average ROUGE: {avg_rouge_score}')
print(f'Average METEOR: {avg_meteor_score}')

gen_question: What would the 13th century Mongolian pronunciation have closely matched? <-> target_question: What spelling of Genghis most closely matches its probable pronunciation? = "Jenggis,Chinggis | 0
gen_question: What is the Mongolian name for Genghis Khan? <-> target_question: What is the Mongolian spelling of Genghis Khan? = Chinggis Khaan | 1
gen_question: What are some of the names of Genghis Khan's titles? <-> target_question: How is Genghis Khan spelled in Turkic? = Cengiz Han, Çingiz Xan, Çingiz Han, Chingizxon, Çıñğız Xan, Chengez Khan, Chinggis Khan, Chinggis Xaan, Chingis Khan, Jenghis Khan, Chinggis Qan, Djingis Kahn,Cengiz Han | 2
gen_question: What is the pinyin for the title of Genghis Khan? <-> target_question: How is Temüjin written in pinyin? = Tiěmùzhēn | 3
gen_question: What are some languages in which the title of Genghis Khan can be spelled? <-> target_question: What are alternate English spelling of Genghis? = Chinghiz, Chinghis, and Chingiz | 4
gen_questi

### QA Evaluation

In [5]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import pipeline

tokenizer_qa = AutoTokenizer.from_pretrained("damapika/roberta-base_mod")
model_qa = AutoModelForQuestionAnswering.from_pretrained("damapika/roberta-base_mod")

dataset_name = "squad"
val_perc = 60
val_dataset = datasets.load_dataset(dataset_name, split=f'validation[:{val_perc}%]')
test_dataset = datasets.load_dataset(dataset_name,split=f"validation[{val_perc}%:]")



In [6]:
question_answerer = pipeline("question-answering", model="damapika/roberta-base_mod")

tp = 0

count = 0
for row in test_dataset:
    context,target_question,answer = row["context"],row["question"],','.join(set(row["answers"]["text"]))

    gen_question = run_model(context,answer)[0][0]

    gen_answer, qa_score = question_answerer(question=gen_question, context=context)['answer'],question_answerer(question=gen_question, context=context)['score']
    print(f'gen_question: {gen_question} <-> target_question: {target_question} = target_answer: {answer} gen_answer: {gen_answer} | {count}')

    # Accuracy and F1 score
    if gen_answer in answer:
        tp += 1
    
    count+=1


print(f'QA Accuracy Score: {tp/count}')

gen_question: What would the 13th century Mongolian pronunciation have closely matched?  <-> target_question: What spelling of Genghis most closely matches its probable pronunciation? = target_answer: "Jenggis,Chinggis gen_answer: Chinggis | 0
gen_question: What is the Mongolian name for Genghis Khan?  <-> target_question: What is the Mongolian spelling of Genghis Khan? = target_answer: Chinggis Khaan gen_answer: Chinggis Khaan | 1
gen_question: What are some of the names of Genghis Khan's titles?  <-> target_question: How is Genghis Khan spelled in Turkic? = target_answer: Cengiz Han, Çingiz Xan, Çingiz Han, Chingizxon, Çıñğız Xan, Chengez Khan, Chinggis Khan, Chinggis Xaan, Chingis Khan, Jenghis Khan, Chinggis Qan, Djingis Kahn,Cengiz Han gen_answer: Chinghiz, Chinghis, and Chingiz | 2
gen_question: What is the pinyin for the title of Genghis Khan?  <-> target_question: How is Temüjin written in pinyin? = target_answer: Tiěmùzhēn gen_answer: Chéngjísī Hán | 3
gen_question: What are s