# BERT Training script

In order to train a BERT model, we first need to generate positive and negative samples. To make things more realistic, we will first retrieve a top-10 results for each trainint query using Anserini, and then, randomly pick a few that are not relevant (2) as "negative sampling".

In [1]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64" # Make sure this is the sama JAVA_HOME as the installed version on the previous notebook!
data_home = "/ssd2/arthur/MsMarcoTREC/"
def path(x):
    return os.path.join(data_home, x)

try:
    import pyserini
except:
    !pip install pyserini==0.8.1.0 # install pyserini
try:
    import tqdm
except:
    !pip install tqdm # Good for progress bars!

In [2]:
import jnius_config
jnius_config.add_options('-Xmx16G') # Adjust to your machine. Probably less than 16G.
from pyserini.search import pysearch
import subprocess
from tqdm.auto import tqdm
import random
import pickle
import sys
import unicodedata
import string
import re
import os
from collections import defaultdict
import math

In [3]:
# Anserini uses this "SimpleSearcher" object for interfacing with the index.
index_path = path("lucene-index.msmarco-doc.pos+docvectors+rawdocs")
searcher = pysearch.SimpleSearcher(index_path)

## Extracting Anserini top-10
We will use pyserini to retrieve the top-10 results using BM25. It doesn't need to be perfect, so, we won't care about fine-tunning it. Default settings should be enough

The way this works is by:
1. submiting each query as a new search on Anserini, with the `SimpleSearcher.search()` method
2. For each query, find $neg_samples$ negative samples from the top-$k$ results from BM25.
3. Store these and the positive samples in a list

obs.: Potentinaly, it would be faster to use Anserini's `batch_search()` method, since it works in multiple threads. However, the lack of feedback (i.e. How may queries have been processed already) and higher memory footprint could cause issues.

### Loading all relevant docs from the qrels file

In [4]:
# Load the relevant query-document pairs
relevant_docs = defaultdict(lambda:[])
for file in [path("qrels/msmarco-doctrain-qrels.tsv"), path("qrels/msmarco-docdev-qrels.tsv")]:
    for line in open(file):
        query_id, _, doc_id, rel = line.split()
        assert rel == "1"
        relevant_docs[query_id].append(doc_id)                            

### Get the top-10 using BM25 and create a training set based on this
Some notes:

- If it finds the `.pkl` file created in the end of the loop, it won't re-compute everything.
- Each query is sanitized before being submitted to Anserini. (lines 22-32)
- The code will "batch" a number of queries to be submitted at once to Anserini, and will run these in parallel. This is much faster than one at a time, and more efficient than all of the queries at once.
- We store the end results in a pickle file, that is a list with the triples `query_id, doc_id, label`. 
- Should take about 1.5h to finish.
- Each element in the output list is: `[query_id, document_id, label]` where `label` is `1` for relevant and `0` for non-relevant
 
### **PAY ATTENTION TO YOUR MACHINE**
this notebook was ran at DeepIR, with 56 threads and 128GB of memory. Make sure to pick a fair number of threads, and a batchsize that fits confortably on memory. BE MINDFULL OF FAIR USAGE OF THE MACHINE. Check if someone else is using the machine, and chose a fair number of threads/batch size. In this configuration, 42 threads and batch_size 10000, this took about 6 minutes to finish. YMMV.

In [5]:
pattern = re.compile('([^\s\w]|_)+')

anserini_top_10 = defaultdict(lambda:[])
searcher.set_bm25_similarity(0.9, 0.4)
pairs_per_split = defaultdict(lambda: [])
threads = 42 # Number of Threads to use when retrieving
k = 10       # Number of documents to retrieve 
neg_samples = 2 # Number of negatives samples to use
batch_size = 10000 # Batch size for each retrieval step on Anserini

query_texts = dict()
for split in ["train", "dev"]:
    file_path = path(f"queries/msmarco-doc{split}-queries.tsv")
    run_search=True
    if os.path.isfile(file_path):
        print(f"Already found file {file_path}. Cowardly refusing to run this again. Will only load querytexts.")
        pairs_per_split[split] = pickle.load(open(path(f"{split}_triples.pkl"), 'rb'))
        run_search = False
    number_of_queries = int(subprocess.run(f"wc -l {file_path}".split(), capture_output=True).stdout.split()[0])
    number_of_batches = math.ceil(number_of_queries/batch_size)
    pbar = tqdm(total=number_of_batches, desc="Retrieval batches")
    queries = []
    query_ids = []
    for idx, line in enumerate(open(file_path, encoding="utf-8")):
        query_id, query = line.strip().split("\t")
        query_ids.append(query_id)
        query = unicodedata.normalize("NFKD", query) # Force queries into UTF-8
        query = pattern.sub(' ',query) # Remove non-ascii characters. It clears up most of the issues we may find on the query datasets
        query_texts[query_id] = query
        if run_search is False:
            continue
        queries.append(query)
        if len(queries) == batch_size or idx == number_of_queries-1:
            results = searcher.batch_search(queries, query_ids, k=k, threads=threads)
            pbar.update()
            for query, query_id in zip(queries, query_ids):
                retrieved_docs_ids = [hit.docid for hit in results[query_id]]
                relevant_docs_for_query = relevant_docs[query_id]
                retrieved_non_relevant_documents = set(retrieved_docs_ids).difference(set(relevant_docs_for_query))
                  
                if len(retrieved_non_relevant_documents) < 2:
                    print(f"query {query} has less than 2 retrieved docs.")
                    continue
                random_negative_samples = random.sample(retrieved_non_relevant_documents, neg_samples)
                pairs_per_split[split] += [(query_id, doc_id, 1) for doc_id in relevant_docs_for_query]
                pairs_per_split[split] += [(query_id, doc_id, 0) for doc_id in random_negative_samples]
            queries = []
            query_ids = []
    pickle.dump(pairs_per_split[split], open(path(f"{split}_triples.pkl"), 'wb'))
    pbar.close()



Already found file /ssd2/arthur/MsMarcoTREC/queries/msmarco-doctrain-queries.tsv. Cowardly refusing to run this again. Will only load querytexts.


HBox(children=(FloatProgress(value=0.0, description='Retrieval batches', max=37.0, style=ProgressStyle(descrip…


Already found file /ssd2/arthur/MsMarcoTREC/queries/msmarco-docdev-queries.tsv. Cowardly refusing to run this again. Will only load querytexts.


HBox(children=(FloatProgress(value=0.0, description='Retrieval batches', max=1.0, style=ProgressStyle(descript…




## Dataset Creation

This dataset is too big to fit in memory. Therefore, it's a good idea to leave it in disk, and retrieve as needed.

To do so, we will create three files: 
- `msmarco_samples.txt`: file with every sample already tokenized and in the right format to be used as input to BERT.
- `msmarco_offset.pkl`: pickle file with a dictionary with the file address for each of these samples. Will make it WAY faster to retrieve data from disk. 
- `msmarco_index.pkl`: A pickle file with a dictionary mapping each sample id (`queryid_docid`) to it's numbered position on the previous file. This will enable us to find a sample by index, and not only ID.

In [6]:
from torch.utils.data import Dataset
import torch

# This is our main Dataset class.
class MsMarcoDataset(Dataset):
    def __init__(self,
                 samples,
                 tokenizer,
                 searcher,
                 split,
                 tokenizer_batch=8000):
        '''Initialize a Dataset object. 
        Arguments:
            samples: A list of samples. Each sample should be a tuple with (query_id, doc_id, <label>), where label is optional
            tokenizer: A tokenizer object from Hugging Face's Tokenizer lib. (need to implement encode_batch())
            searcher: A PySerini Simple Searcher object. Should implement the .doc() method
            split: A strong indicating if we are in a train, dev or test dataset.
            tokenizer_batch: How many samples to be tokenized at once by the tokenizer object.
            The biggest bottleneck is the searcher, not the tokenizer.
        '''
        self.searcher = searcher
        self.split = split
        # If we already have the data pre-computed, we shouldn't need to re-compute it.
        self.split = split
        if (os.path.isfile(path(f"{split}_msmarco_samples.tsv"))
                and os.path.isfile(path(f"{split}_msmarco_offset.pkl"))
                and os.path.isfile(path(f"{split}_msmarco_index.pkl"))):
            print("Already found every meaningful file. Cowardly refusing to re-compute.")
            self.samples_offset_dict = pickle.load(open(path(f"{split}_msmarco_offset.pkl"), 'rb'))
            self.index_dict = pickle.load(open(path(f"{split}_msmarco_index.pkl"), 'rb'))
            return
        self.tokenizer = tokenizer
        print("Loading and tokenizing dataset...")
        self.samples_offset_dict = dict()
        self.index_dict = dict()

        self.samples_file = open(path(f"{split}_msmarco_samples.tsv"),'w',encoding="utf-8")
        self.processed_samples = 0
        query_batch = []
        doc_batch = []
        sample_ids_batch = []
        labels_batch = []
        number_of_batches = math.ceil(len(samples) // tokenizer_batch)
        # A progress bar to display how far we are.
        batch_pbar = tqdm(total=number_of_batches, desc="Tokenizer batches")
        for i, sample in enumerate(samples):
            if split=="train" or split == "dev":
                label = sample[2]
                labels_batch.append(label)
            query_batch.append(query_texts[sample[0]])
            doc_batch.append(self._get_document_content_from_id(sample[1]))
            sample_ids_batch.append(f"{sample[0]}_{sample[1]}")
            #If we hit the number of samples for this batch OR this is the last sample
            if len(query_batch) == tokenizer_batch or i == len(samples) - 1:
                self._tokenize_and_dump_batch(doc_batch, query_batch, labels_batch, sample_ids_batch)
                batch_pbar.update()
                query_batch = []
                doc_batch = []
                sample_ids_batch = []
                if split == "train" or split == "dev":
                    labels_batch = []
        batch_pbar.close()
        # Dump files in disk, so we don't need to go over it again.
        self.samples_file.close()
        pickle.dump(self.index_dict, open(path(f"{self.split}_msmarco_index.pkl"), 'wb'))
        pickle.dump(self.samples_offset_dict, open(path(f"{self.split}_msmarco_offset.pkl"), 'wb'))

    def _tokenize_and_dump_batch(self, doc_batch, query_batch, labels_batch,
                                 sample_ids_batch):
        '''tokenizes and dumps the samples in the current batch
        It also store the positions from the current file into the samples_offset_dict.
        '''
        # Use the tokenizer object
        tokens = self.tokenizer.encode_batch(list(zip(query_batch, doc_batch)))
        for idx, (sample_id, token) in enumerate(zip(sample_ids_batch, tokens)):
            #BERT supports up to 512 tokens. If we have more than that, we need to remove some tokens from the document
            if len(token.ids) >= 512:
                token_ids = token.ids[:511]
                token_ids.append(tokenizer.token_to_id("[SEP]"))
                segment_ids = token.type_ids[:512]
            # With less tokens, we need to "pad" the vectors up to 512.
            else:
                padding = [0] * (512 - len(token.ids))
                token_ids = token.ids + padding
                segment_ids = token.type_ids + padding
            # How far in the file are we? This is where we need to go to find the documents later.
            file_location = self.samples_file.tell()
            # If we have labels
            if self.split=="train" or split == "dev":
                self.samples_file.write(f"{sample_id}\t{token_ids}\t{segment_ids}\t{labels_batch[idx]}\n")
            else:
                self.samples_file.write(f"{sample_id}\t{token_ids}\t{segment_ids}\n")
            self.samples_offset_dict[sample_id] = file_location
            self.index_dict[self.processed_samples] = sample_id
            self.processed_samples += 1

    def _get_document_content_from_id(self, doc_id):
        '''Get the raw text value from the doc_id
        There is probably an easier way to do that, but this works.
        '''
        doc_text = self.searcher.doc(doc_id).lucene_document().getField("raw").stringValue()
        return doc_text[7:-8]

    def __getitem__(self, idx):
        '''Returns a sample with index idx
        DistilBERT does not take into account segment_ids. (indicator if the token comes from the query or the document) 
        However, for the sake of completness, we are including it here, together with the attention mask
        position_ids, with the positional encoder, is not needed. It's created for you inside the model.
        '''
        if isinstance(idx, int):
            idx = self.index_dict[idx]
        with open(path(f"{self.split}_msmarco_samples.tsv"), 'r', encoding="utf-8") as inf:
            inf.seek(self.samples_offset_dict[idx])
            line = inf.readline().split("\t")
            try:
                sample_id = line[0]
                input_ids = eval(line[1])
                token_type_ids = eval(line[2])
                input_mask = [1] * 512
            except:
                print(line, idx)
                raise IndexError
            # If it's a training dataset, we also have a label tag.
            if split=="train" or split == "dev":
                label = int(line[3])
                return (torch.tensor(input_ids, dtype=torch.long),
                        torch.tensor(input_mask, dtype=torch.long),
                        torch.tensor(token_type_ids, dtype=torch.long),
                        torch.tensor([label], dtype=torch.long))
            return (torch.tensor(input_ids, dtype=torch.long),
                    torch.tensor(input_mask, dtype=torch.long),
                    torch.tensor(token_type_ids, dtype=torch.long))
    def __len__(self):
        return len(self.samples_offset_dict)

## Training script
For actually training our model, we need to do the following:
1. Create a DataLoader object for train and one for dev. This will help with batching and such.
2. Load a BERT pre-trained model. For this example, we are using DistilBert. Because it's smaller and faster.
    - For ease of use, we will use the `DistilBertForSequenceClassification` model. It's ready for computing whether two senteces are related.
    - Also note that, for this model, weirdly enough, $1$ is NOT RELEVANT and  $0$ is RELEVANT
    - Alternativelly, we can use the default `DistilBert` and extract the `[CLS]` token embedding and feed it to a shallow NN using PyTorch or even Sklearn and a linear regression.
3. Create a training loop that for every $X$ samples will check the results on the dev dataset.
4. Store breakpoints every $N$ steps

In [7]:
from transformers import DistilBertForSequenceClassification
from torch.utils.data import DataLoader
from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer("/ssd2/arthur/bert-axioms/tokenizer/bert-base-uncased-vocab.txt", lowercase=True)

In [8]:
train_dataset = MsMarcoDataset(pairs_per_split["train"], tokenizer, searcher, split = "train")
dev_dataset = MsMarcoDataset(pairs_per_split["dev"], tokenizer, searcher, split = "dev")

Already found every meaningful file. Cowardly refusing to re-compute.
Already found every meaningful file. Cowardly refusing to re-compute.


We NEED to use GPUs for this. If you don't have access to some GPUs you can try Google Colab OR if you are a MSc from WIS, get in touch.

In [9]:
from transformers import AdamW, get_linear_schedule_with_warmup

# With these configurations, on DeepIR, it takes ~3h/batch to train, with ~2batches/s
GPUS_TO_USE = [2,4,5,6,7] # If you have multiple GPUs, pick the ones you want to use.
number_of_cpus = 24 # Number of CPUS to use when loading your dataset.
n_epochs = 2 # How may passes over the whole dataset to complete
weight_decay = 0.0 # Some papers define a weight decay, meaning, the weights on some layers will decay slower overtime. By default, we don't do this.
lr = 0.00005 # Learning rate for the fine-tunning.
warmup_proportion = 0.1 # Percentage of training steps to perform before we start to decrease the learning rate.
steps_to_print = 1000 # How many steps to wait before printing loss
steps_to_eval = 2000 # How many steps to wait before running an eval step

# This is our base model
try:
    del model
    torch.cuda.empty_cache() # Make sure we have a clean slate
except:
    pass
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

if torch.cuda.is_available():
    # Asssign the model to GPUs, specifying to use Data parallelism.
    model = torch.nn.DataParallel(model, device_ids=GPUS_TO_USE)
    # The main model should be on the first GPU
    device = torch.device(f"cuda:{GPUS_TO_USE[0]}") 
    model.to(device)
    # For a 1080Ti, 16 samples fit on a GPU confortably. So, the train batch size will be 16*the number of GPUS
    train_batch_size = len(GPUS_TO_USE) * 16
    print(f"running on {len(GPUS_TO_USE)} GPUS, on {train_batch_size}-sized batches")
else:
    print("Are you sure about it? We will try to run this in CPU, but it's a BAD idea...")
    device = torch.device("cpu")
    train_batch_size = 16
    model.to(device)

# A data loader is a nice device for generating batches for you easily.
# It receives any object that implementes __getitem__(self, idx) and __len__(self)
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, num_workers=number_of_cpus,shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=32, num_workers=number_of_cpus,shuffle=True)

#how many optimization steps to run, given the NUMBER OF BATCHES. (The len of the dataloader is the number of batches).
num_train_optimization_steps = len(train_data_loader) * n_epochs

#which layers will not have a linear weigth decay when training
no_decay = ['bias', 'LayerNorm.weight']

#all parameters to be optimized by our fine tunning.
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any( nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in model.named_parameters() if any( nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

#We use the AdamW optmizer here.
optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8) 

# How many steps to wait before we start to decrease the learning rate
warmup_steps = num_train_optimization_steps * warmup_proportion 
# A scheduler to take care of the above.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps)
print(f"*********Total optmization steps: {num_train_optimization_steps}*********")

running on 5 GPUS, on 80-sized batches
*********Total optmization steps: 42042*********


In [10]:
import warnings
import numpy as np
import datetime

try:
    from sklearn.metrics import f1_score, average_precision_score, accuracy_score, roc_auc_score
except:
    !pip install sklearn
    from sklearn.metrics import f1_score, average_precision_score, accuracy_score, roc_auc_score

global_step = 0 # Number of steps performed so far
tr_loss = 0.0 # Training loss
model.zero_grad() # Initialize gradients to 0

for _ in tqdm(range(n_epochs), desc="Epochs"):
    for step, batch in tqdm(enumerate(train_data_loader), desc="Batches", total=len(train_data_loader)):
        model.train()
        # get the batch inpute
        inputs = {
            'input_ids': batch[0].to(device),
            'attention_mask': batch[1].to(device),
            'labels': batch[3].to(device)
        }
        # Run through the network.
        
        with warnings.catch_warnings():
            # There is a very annoying warning here when we are using multiple GPUS,
            # As described here: https://github.com/huggingface/transformers/issues/852.
            # We can safely ignore this.
            warnings.simplefilter("ignore")
            outputs = model(**inputs)
        loss = outputs[0]

        loss = loss.sum()/len(model.device_ids) # Average over all GPUS.
        # Clipping gradients. Avoud gradient explosion, if the gradient is too large.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Backward pass on the network
        loss.backward()
        tr_loss += loss.item()
        # Run the optimizer with the gradients
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        if step % steps_to_print == 0:
            # Logits is the actual output from the network. 
            # This is the probability of being relevant or not.
            # You can check its shape (Should be a vector sized 2) with logits.shape()
            logits = outputs[1]
            # Send the logits to the CPU and in numpy form. Easier to check what is going on.
            preds = logits.detach().cpu().numpy()
            
            # Bring the labels to CPU too.
            tqdm.write(f"Training loss: {loss.item()} Learning Ragte: {scheduler.get_last_lr()[0]}")
        global_step += 1
        
        # Run an evluation step over the eval dataset. Let's see how we are going.
        if global_step%steps_to_eval == 0:
            eval_loss = 0.0
            nb_eval_steps = 0
            preds = None
            out_label_ids = None
            for batch in tqdm(dev_data_loader, desc="Dev batch"):
                model.eval()
                with torch.no_grad(): # Avoid upgrading gradients here
                    inputs = {'input_ids': batch[0].to(device),
                      'attention_mask': batch[1].to(device),
                      'labels': batch[3].to(device)}
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        outputs = model(**inputs)
                    tmp_eval_loss, logits = outputs[:2] # Logits is the actual output. Probabilities between 0 and 1.
                    eval_loss += tmp_eval_loss.mean().item()
                    nb_eval_steps += 1
                    # Concatenate all outputs to evaluate in the end.
                    if preds is None:
                        preds = logits.detach().cpu().numpy() # PRedictions into numpy mode
                        out_label_ids = inputs['labels'].detach().cpu().numpy().flatten() # Labels assigned by model
                    else:
                        batch_predictions = logits.detach().cpu().numpy()
                        preds = np.append(preds, batch_predictions, axis=0)
                        out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy().flatten(), axis=0)
                eval_loss = eval_loss / nb_eval_steps
            results = {}
            results["ROC Dev"] = roc_auc_score(out_label_ids, preds[:, 1])
            preds = np.argmax(preds, axis=1)
            results["Acuracy Dev"] = accuracy_score(out_label_ids, preds)
            results["F1 Dev"] = f1_score(out_label_ids, preds)
            results["AP Dev"] = average_precision_score(out_label_ids, preds)
            tqdm.write("***** Eval results *****")
            for key in sorted(results.keys()):
                tqdm.write(f"  {key} = {str(results[key])}")
            output_dir = path(f"checkpoints/checkpoint-{global_step}")
            if not os.path.isdir(output_dir):
                os.makedirs(path(output_dir))
#             print(f"Saving model checkpoint to {output_dir}")
            model_to_save = model.module if hasattr(model, 'module') else model
            model_to_save.save_pretrained(output_dir)

# Save final model 
output_dir = path(f"models/distilBERT-{str(datetime.date.today())}")
if not os.path.isdir(output_dir):
    os.makedirs(path(output_dir))
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=21021.0, style=ProgressStyle(description_wi…

Train ROC: 0.6017579445571333
Train accuracy: 0.5875
Training loss: 0.6824558973312378
Learning rate: 1.189286903572618e-08
Train ROC: 0.7813390313390314
Train accuracy: 0.725
Training loss: 0.5217803120613098
Learning rate: 5.9583273868988165e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.4734283789961621
  Acuracy Dev = 0.7344386149003148
  F1 Dev = 0.527800582132995
  ROC Dev = 0.7823920346899583
Train ROC: 0.8581818181818182
Train accuracy: 0.8125
Training loss: 0.4277583062648773
Learning rate: 1.1904761904761907e-05
Train ROC: 0.8382913806254768
Train accuracy: 0.825
Training loss: 0.42712002992630005
Learning rate: 1.7851196422624997e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5346556823798122
  Acuracy Dev = 0.7642812172088143
  F1 Dev = 0.6432926829268293
  ROC Dev = 0.8302346199964222
Train ROC: 0.8413696715583509
Train accuracy: 0.85
Training loss: 0.450906366109848
Learning rate: 2.3797630940488087e-05
Train ROC: 0.8988095238095238
Train accuracy: 0.85
Training loss: 0.37316620349884033
Learning rate: 2.9744065458351178e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5421112003521867
  Acuracy Dev = 0.7692759706190976
  F1 Dev = 0.6513604363544111
  ROC Dev = 0.8383614184643108
Train ROC: 0.7827208252740168
Train accuracy: 0.725
Training loss: 0.5373059511184692
Learning rate: 3.5690499976214264e-05
Train ROC: 0.8839031339031339
Train accuracy: 0.7875
Training loss: 0.39931368827819824
Learning rate: 4.1636934494077355e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5487017845350912
  Acuracy Dev = 0.7785099685204617
  F1 Dev = 0.6388831862040649
  ROC Dev = 0.8454836767280532
Train ROC: 0.7886108714408974
Train accuracy: 0.8375
Training loss: 0.4352118670940399
Learning rate: 4.7583369011940445e-05
Train ROC: 0.7998575498575499
Train accuracy: 0.7875
Training loss: 0.5039113759994507
Learning rate: 4.9607799607799605e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5533506899824174
  Acuracy Dev = 0.7814480587618048
  F1 Dev = 0.6433317350503459
  ROC Dev = 0.8502054766730774
Train ROC: 0.8254985754985754
Train accuracy: 0.7625
Training loss: 0.48662814497947693
Learning rate: 4.894708466137037e-05
Train ROC: 0.8504542278127183
Train accuracy: 0.775
Training loss: 0.4507910907268524
Learning rate: 4.828636971494114e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5661902327233137
  Acuracy Dev = 0.7871983210912906
  F1 Dev = 0.6679329316216925
  ROC Dev = 0.8568459000104849
Train ROC: 0.8676761026991442
Train accuracy: 0.7625
Training loss: 0.4336676597595215
Learning rate: 4.762565476851191e-05
Train ROC: 0.9045424621461489
Train accuracy: 0.8125
Training loss: 0.41981038451194763
Learning rate: 4.696493982208268e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5595527075761935
  Acuracy Dev = 0.7861070304302203
  F1 Dev = 0.6433370660694289
  ROC Dev = 0.8604986653775032
Train ROC: 0.8386666666666668
Train accuracy: 0.7875
Training loss: 0.4822412431240082
Learning rate: 4.630422487565345e-05
Train ROC: 0.8916363636363637
Train accuracy: 0.8125
Training loss: 0.3781545162200928
Learning rate: 4.5643509929224216e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5685093263044734
  Acuracy Dev = 0.7894648478488983
  F1 Dev = 0.6661341853035144
  ROC Dev = 0.8612058401817764
Train ROC: 0.8610928242264647
Train accuracy: 0.8
Training loss: 0.4607420563697815
Learning rate: 4.498279498279498e-05
Train ROC: 0.8792727272727273
Train accuracy: 0.8125
Training loss: 0.40605759620666504
Learning rate: 4.4322080036365746e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5709956259110867
  Acuracy Dev = 0.7858551941238195
  F1 Dev = 0.6880655416972365
  ROC Dev = 0.8635486376525725
Train ROC: 0.8839031339031338
Train accuracy: 0.8125
Training loss: 0.39896130561828613
Learning rate: 4.366136508993652e-05
Train ROC: 0.925595238095238
Train accuracy: 0.9
Training loss: 0.30391794443130493
Learning rate: 4.300065014350729e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5805320459245902
  Acuracy Dev = 0.7950052465897167
  F1 Dev = 0.6871236386931455
  ROC Dev = 0.8677017962802882
Train ROC: 0.8891369047619048
Train accuracy: 0.7875
Training loss: 0.38106560707092285
Learning rate: 4.233993519707805e-05
Train ROC: 0.8736299161831076
Train accuracy: 0.8125
Training loss: 0.43353912234306335
Learning rate: 4.167922025064882e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5843719330785806
  Acuracy Dev = 0.7976495278069256
  F1 Dev = 0.6896285328011331
  ROC Dev = 0.8690233542170219
Train ROC: 0.9032882011605416
Train accuracy: 0.8625
Training loss: 0.4046548008918762
Learning rate: 4.101850530421959e-05
Train ROC: 0.8981191222570533
Train accuracy: 0.875
Training loss: 0.3345804214477539
Learning rate: 4.035779035779036e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5732463267039056
  Acuracy Dev = 0.7840923399790136
  F1 Dev = 0.6979448032883148
  ROC Dev = 0.8687817549867656
Train ROC: 0.9106666666666666
Train accuracy: 0.8125
Training loss: 0.39230263233184814
Learning rate: 3.9697075411361125e-05
Train ROC: 0.9059065934065933
Train accuracy: 0.8
Training loss: 0.39423346519470215
Learning rate: 3.9036360464931894e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5861583283817675
  Acuracy Dev = 0.8003357817418678
  F1 Dev = 0.6820399705902012
  ROC Dev = 0.8744936155128205
Train ROC: 0.8058181818181818
Train accuracy: 0.775
Training loss: 0.47148066759109497
Learning rate: 3.837564551850266e-05
Train ROC: 0.8763636363636363
Train accuracy: 0.8
Training loss: 0.4132147431373596
Learning rate: 3.7714930572073424e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5907510539538373
  Acuracy Dev = 0.8018887722980063
  F1 Dev = 0.6939834024896266
  ROC Dev = 0.8763013368041515
Train ROC: 0.8833010960670535
Train accuracy: 0.7875
Training loss: 0.4441492259502411
Learning rate: 3.705421562564419e-05
Train ROC: 0.8772321428571428
Train accuracy: 0.8125
Training loss: 0.3992480933666229
Learning rate: 3.639350067921497e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5875612980804374
  Acuracy Dev = 0.7983630640083945
  F1 Dev = 0.6980894922071392
  ROC Dev = 0.8744116726579988
Train ROC: 0.8689927583936801
Train accuracy: 0.8
Training loss: 0.43431875109672546
Learning rate: 3.5732785732785736e-05
Train ROC: 0.9059561128526645
Train accuracy: 0.825
Training loss: 0.3638037443161011
Learning rate: 3.50720707863565e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5895932620272767
  Acuracy Dev = 0.8019727177334732
  F1 Dev = 0.6881692002643754
  ROC Dev = 0.8770030138703295
Train ROC: 0.9283865401207938
Train accuracy: 0.875
Training loss: 0.38181766867637634
Learning rate: 3.4411355839927267e-05
Train ROC: 0.9494301994301995
Train accuracy: 0.9
Training loss: 0.2794446647167206
Learning rate: 3.3750640893498035e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5908281271775208
  Acuracy Dev = 0.8036096537250788
  F1 Dev = 0.6819386853375025
  ROC Dev = 0.8776348415199899
Train ROC: 0.8875
Train accuracy: 0.8375
Training loss: 0.3492681384086609
Learning rate: 3.3089925947068804e-05
Train ROC: 0.9163636363636364
Train accuracy: 0.85
Training loss: 0.3378193974494934
Learning rate: 3.242921100063957e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5926705793325872
  Acuracy Dev = 0.8010073452256034
  F1 Dev = 0.7050332856342936
  ROC Dev = 0.8772818130083159
Train ROC: 0.8832417582417582
Train accuracy: 0.8125
Training loss: 0.39545726776123047
Learning rate: 3.176849605421034e-05
Train ROC: 0.8398437500000001
Train accuracy: 0.725
Training loss: 0.4513436257839203
Learning rate: 3.110778110778111e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5897693029618618
  Acuracy Dev = 0.8036935991605456
  F1 Dev = 0.6719045948789898
  ROC Dev = 0.8792138735543703
Train ROC: 0.9107505070993915
Train accuracy: 0.8125
Training loss: 0.37642183899879456
Learning rate: 3.0447066161351874e-05
Train ROC: 0.8763440860215054
Train accuracy: 0.8125
Training loss: 0.3669411838054657
Learning rate: 2.9786351214922643e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5984480191476851
  Acuracy Dev = 0.8065897166841552
  F1 Dev = 0.7008180755746007
  ROC Dev = 0.8799395075989522
Train ROC: 0.9187979539641944
Train accuracy: 0.825
Training loss: 0.3746066093444824
Learning rate: 2.912563626849341e-05
Train ROC: 0.9016927083333333
Train accuracy: 0.8625
Training loss: 0.3908333480358124
Learning rate: 2.846492132206418e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5912168546520766
  Acuracy Dev = 0.802392444910808
  F1 Dev = 0.693169968717414
  ROC Dev = 0.8808847776677455
Train ROC: 0.9014675052410901
Train accuracy: 0.85
Training loss: 0.3724188506603241
Learning rate: 2.7804206375634945e-05



HBox(children=(FloatProgress(value=0.0, description='Batches', max=21021.0, style=ProgressStyle(description_wi…

Train ROC: 0.9216666666666667
Train accuracy: 0.875
Training loss: 0.3073306977748871
Learning rate: 2.7776456347884916e-05
Train ROC: 0.9636617749825297
Train accuracy: 0.875
Training loss: 0.23995323479175568
Learning rate: 2.7115741401455684e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.598748947823059
  Acuracy Dev = 0.8066736621196222
  F1 Dev = 0.7016839378238342
  ROC Dev = 0.8803806782894426
Train ROC: 0.8902340597255851
Train accuracy: 0.8625
Training loss: 0.37392547726631165
Learning rate: 2.6455026455026456e-05
Train ROC: 0.8761755485893418
Train accuracy: 0.8125
Training loss: 0.3874090909957886
Learning rate: 2.5794311508597225e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…

Train ROC: 0.8923636363636364
Train accuracy: 0.8375
Training loss: 0.3760998845100403
Learning rate: 2.4472881615738758e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6014140481669834
  Acuracy Dev = 0.808604407135362
  F1 Dev = 0.7019218198457315
  ROC Dev = 0.8818576672311645
Train ROC: 0.9086538461538461
Train accuracy: 0.8125
Training loss: 0.3471875488758087
Learning rate: 2.3812166669309527e-05
Train ROC: 0.9189378057302586
Train accuracy: 0.85
Training loss: 0.33646488189697266
Learning rate: 2.3151451722880292e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.5964450211995388
  Acuracy Dev = 0.804784889821616
  F1 Dev = 0.7022978941304487
  ROC Dev = 0.8792647008763546
Train ROC: 0.9585454545454545
Train accuracy: 0.9125
Training loss: 0.2510432004928589
Learning rate: 2.249073677645106e-05
Train ROC: 0.9397321428571428
Train accuracy: 0.8875
Training loss: 0.2831451892852783
Learning rate: 2.183002183002183e-05

***** Eval results *****
  AP Dev = 0.5944969588656094
  Acuracy Dev = 0.8054564533053515
  F1 Dev = 0.6881517863150104
  ROC Dev = 0.8814327519292304
Train ROC: 0.9425287356321839
Train accuracy: 0.875
Training loss: 0.3084748685359955
Learning rate: 2.1169306883592597e-05
Train ROC: 0.9017857142857144
Train accuracy: 0.825
Training loss: 0.35808029770851135
Learning rate: 2.0508591937163366e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6040351162609724
  Acuracy Dev = 0.8090661070304302
  F1 Dev = 0.7109727428680347
  ROC Dev = 0.8843723809937099
Train ROC: 0.9475890985324947
Train accuracy: 0.8875
Training loss: 0.2689463794231415
Learning rate: 1.8526447097875668e-05
Train ROC: 0.91015625
Train accuracy: 0.875
Training loss: 0.3051362633705139
Learning rate: 1.7865732151446436e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.605562752790373
  Acuracy Dev = 0.809485834207765
  F1 Dev = 0.7148690244362083
  ROC Dev = 0.8858565483086762
Train ROC: 0.9546703296703296
Train accuracy: 0.875
Training loss: 0.26548585295677185
Learning rate: 1.7205017205017205e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…

Train ROC: 0.9007686932215233
Train accuracy: 0.85
Training loss: 0.3778061866760254
Learning rate: 1.5222872365729507e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6049565812285839
  Acuracy Dev = 0.8114165792235047
  F1 Dev = 0.6987192382485079
  ROC Dev = 0.8856391200999794
Train ROC: 0.9308333333333334
Train accuracy: 0.825
Training loss: 0.3253900408744812
Learning rate: 1.4562157419300276e-05
Train ROC: 0.9716024340770791
Train accuracy: 0.875
Training loss: 0.22869662940502167
Learning rate: 1.3901442472871042e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6049439806644868
  Acuracy Dev = 0.8080167890870934
  F1 Dev = 0.7188345217605114
  ROC Dev = 0.88657091337748
Train ROC: 0.8993710691823898
Train accuracy: 0.825
Training loss: 0.3934527039527893
Learning rate: 1.3240727526441813e-05


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6062855956509797
  Acuracy Dev = 0.8112486883525708
  F1 Dev = 0.7080438875543725
  ROC Dev = 0.887348135786378
Train ROC: 0.9282990083905416
Train accuracy: 0.8625
Training loss: 0.31258562207221985
Learning rate: 1.1919297633583346e-05
Train ROC: 0.880184331797235
Train accuracy: 0.775
Training loss: 0.4340452253818512
Learning rate: 1.1258582687154115e-05


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6085985056226089
  Acuracy Dev = 0.812507869884575
  F1 Dev = 0.7108924988673871
  ROC Dev = 0.8874412128490645
Train ROC: 0.889397406559878
Train accuracy: 0.825
Training loss: 0.37193986773490906
Learning rate: 1.0597867740724883e-05
Train ROC: 0.9560931899641578
Train accuracy: 0.875
Training loss: 0.2541200816631317
Learning rate: 9.937152794295652e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6081039524874794
  Acuracy Dev = 0.810912906610703
  F1 Dev = 0.717749514441451
  ROC Dev = 0.8877547939998727
Train ROC: 0.8971354166666667
Train accuracy: 0.8
Training loss: 0.3900657892227173
Learning rate: 9.276437847866419e-06


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train ROC: 0.98
Train accuracy: 0.8625
Training loss: 0.24718227982521057
Learning rate: 8.615722901437187e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6097731075456313
  Acuracy Dev = 0.8122560335781742
  F1 Dev = 0.7176314626601856
  ROC Dev = 0.8886919938489388
Train ROC: 0.9105454545454545
Train accuracy: 0.825
Training loss: 0.37262165546417236
Learning rate: 7.955007955007956e-06
Train ROC: 0.9341692789968652
Train accuracy: 0.8375
Training loss: 0.29048022627830505
Learning rate: 7.294293008578723e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6086857166213097
  Acuracy Dev = 0.8128436516264428
  F1 Dev = 0.7088855520010445
  ROC Dev = 0.8885818646815047
Train ROC: 0.9749784296807593
Train accuracy: 0.9125
Training loss: 0.20650334656238556
Learning rate: 6.633578062149491e-06
Train ROC: 0.9628127112914131
Train accuracy: 0.9125
Training loss: 0.25391659140586853
Learning rate: 5.9728631157202585e-06


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…

Train ROC: 0.9587242026266416
Train accuracy: 0.8875
Training loss: 0.30974555015563965
Learning rate: 4.651433222861795e-06
Train ROC: 0.9170909090909091
Train accuracy: 0.85
Training loss: 0.33698007464408875
Learning rate: 3.33000333000333e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…

Train ROC: 0.9287272727272726
Train accuracy: 0.875
Training loss: 0.3170314431190491
Learning rate: 2.0085734371448653e-06


HBox(children=(FloatProgress(value=0.0, description='Dev batch', max=745.0, style=ProgressStyle(description_wi…


***** Eval results *****
  AP Dev = 0.6107921747560067
  Acuracy Dev = 0.8133473242392445
  F1 Dev = 0.715864800971184
  ROC Dev = 0.8894114087507152


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [11]:
model

DataParallel(
  (module): DistilBertForSequenceClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): F