In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
from nlstruct.text import huggingface_tokenize, regex_sentencize, partition_spans, encode_as_tag, split_into_spans, apply_substitutions, apply_deltas
from nlstruct.dataloaders import load_from_brat, load_genia_ner
from nlstruct.collections import Dataset, Batcher
from nlstruct.utils import merge_with_spans, normalize_vocabularies, factorize_rows, df_to_csr, factorize, torch_global as tg
from nlstruct.modules.crf import BIODecoder, BIOULDecoder
from nlstruct.environment import root, cached
from nlstruct.train import seed_all
from itertools import chain, repeat

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
import math

import pandas as pd
import numpy as np
import re
import string
from transformers import AutoModel, AutoTokenizer

In [4]:
pd.set_option('display.width', 1000)

In [5]:
import scipy

def select_closest_non_overlapping_gold_mentions(
    gold_ids,
    gold_sentence_ids,
    gold_begins,
    gold_ends,
    
    pred_sentence_ids,
    pred_begins,
    pred_ends,
    
    zone_mention_id,
    zone_mask,
    gold_conflicts,
    gold_conflicts_mask,
):
    """
    Select non overlapping gold mentions (in gold_ids) that are the closest to those found by the model
    Gold mentions are described by gold_* tensors, predicted mentions are described by pred_* tensors
    We select at most one gold mention per zone (zone_mention_id + zone_mask) and each time a gold mention
    is selected, its overlaps are removed according to the gold_conflicts + gold_conflicts_mask tensors
    
    Returns
    -------
    torch.Tensor
        Selected gold ids, included in "gold_ids"
    """

    device = gold_begins.device
    
    if len(gold_ids) == 0:
        return torch.as_tensor([], dtype=torch.long, device=device)

    [rel], [remaining_mask], _ = factorize([zone_mention_id], [zone_mask], reference_values=gold_ids)
    remaining_mentions = zone_mention_id[remaining_mask.any(1)]
    rel = rel[remaining_mask.any(1)]
    remaining_mask = remaining_mask[remaining_mask.any(1)]
        
    keep_mask = torch.zeros(gold_ids.max()+1, device=device, dtype=torch.bool)
    zone_scores = torch.full(remaining_mask.shape, fill_value=-1, device=device, dtype=torch.float)
    
    if len(pred_begins):
        PRED, GOLD = 0, 1
        SENTENCE_ID, BEGIN, END = 0, 1, 2
        p = torch.stack([pred_sentence_ids, pred_begins, pred_ends], dim=0).unsqueeze(GOLD+1)
        g = torch.stack([gold_sentence_ids, gold_begins, gold_ends], dim=0).unsqueeze(PRED+1)

        overlap = (torch.min(p[END], g[END]) - torch.max(p[BEGIN], g[BEGIN])).float().clamp(0)
        overlap = overlap * 2 / (p[END] - p[BEGIN] + g[END] - g[BEGIN])
        score = (p[SENTENCE_ID] == g[SENTENCE_ID]) * overlap
        
        zone_scores = score.max(0).values[rel]
        zone_scores[~remaining_mask] = -1
    else:
        zone_scores[remaining_mask] = 0
    while len(remaining_mask):
        best_indexer = torch.arange(zone_scores.shape[0], device=device), zone_scores.argmax(1)
        best_mentions = remaining_mentions[best_indexer]
        conflicts = gold_conflicts[best_mentions][gold_conflicts_mask[best_mentions]]    
        keep_mask[best_mentions] = True
        remaining_mask[best_indexer] = False
        remaining_mask &= ~(remaining_mentions.unsqueeze(-1) == conflicts).any(-1)
        zone_scores[~remaining_mask] = -1
        zone_scores = zone_scores[remaining_mask.any(1)]
        remaining_mentions = remaining_mentions[remaining_mask.any(1)]
        remaining_mask = remaining_mask[remaining_mask.any(1)]
    return keep_mask.nonzero()[:, 0]

def split_zone_mentions(batch, random_perm=True, observed_zone_sizes=None):
    """
    In a batch, splits mentions between 
    - those that we will consider as being observed
    - and those that we will ask the model to recover
    
    Parameters
    ----------
    random_perm: bool
        Shuffle the mentions before splitting them
    observed_zone_sizes: int
        If not None, selects exactly this number of mentions per zone (=overlapping group of mentions)
        Otherwise, any random number from 0 to the maximum of mentions can be observed in each group
    
    Returns
    -------
    (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
        - observed_mentions: flat observed mentions (@mention_id)
        - target_mentions: flat target mentions (@mention_id)
        - zone_target_mentions: mentions to recover, grouped by zone (n_zones * n_mentions_per_zone)
        - target_mask: mask of zone_target_mentions, since every zone can have a different number of picked target mentions
    """
    zone_mention_id = batch["zone", "@mention_id"]
    zone_mention_mask = batch["zone", "mention_mask"]
    n_sentences = len(batch["sentence"])
    device = zone_mention_id.device
    if random_perm:
        perm = torch.rand(zone_mention_id.shape, device=device)
    else:
        perm = torch.zeros(zone_mention_id.shape, device=device, dtype=torch.float)
    perm[~zone_mention_mask] = 2
    perm = perm.argsort(1)

    if observed_zone_sizes is None:
        observed_zone_size = ((zone_mention_mask.sum(-1) + 1) * torch.rand(zone_mention_mask.shape[0], dtype=torch.float, device=device)).long()
    else:
        observed_zone_size = torch.full((zone_mention_mask.shape[0],), fill_value=observed_zone_sizes, device=device, dtype=torch.long)

    # Select mentions that will become features
    zone_observed_mentions = zone_mention_id[torch.arange(perm.shape[0], device=device).unsqueeze(1), perm]
    observed_mask = (torch.arange(zone_mention_mask.shape[1], device=device).unsqueeze(0) < observed_zone_size.unsqueeze(1)) & zone_mention_mask
    observed_mentions = zone_observed_mentions[observed_mask]

    # Select mentions that will be hidden from the model (ie to recover)
    zone_target_mentions = zone_mention_id[torch.arange(perm.shape[0], device=device).unsqueeze(1), perm]
    target_mask = (torch.arange(zone_mention_mask.shape[1], device=device).unsqueeze(0) >= observed_zone_size.unsqueeze(1)) & zone_mention_mask
    zone_target_mentions = zone_target_mentions[target_mask.any(1)]
    target_mask = target_mask[target_mask.any(1)]
    target_mentions = zone_target_mentions[target_mask]
    return target_mentions, observed_mentions, zone_target_mentions, target_mask

def compute_scores(pred_batcher, gold_batcher, queries={}, prefix='val_', verbose=0):
    pred=pd.DataFrame(dict(pred_batcher["mention", ["sentence_id", "begin", "end", "ner_label", "mention_id"]]))
    gold=pd.DataFrame(dict(gold_batcher["mention", ["@zone_id", "begin", "end", "ner_label", "mention_id"]]))
    gold["sentence_id"] = gold_batcher["zone", "sentence_id"][gold["@zone_id"]]

    # Merge on spans and ner_label
    merged = merge_pred_and_gold(
        pred, gold, span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["sentence_id", ("begin", "end"), "ner_label"], atom_gold_level=["mention_id"], atom_pred_level=["mention_id"])
    
    merged["ner_label"] = np.asarray(vocs["ner_label"])[merged["ner_label"]].astype(str)
    metrics = {
        **compute_metrics(merged, prefix=prefix),
        #**compute_metrics(merged.query("ner_label == {}".format(list(vocs["ner_label"]).index(c["sosy"]))), prefix=prefix+"sosy_"),
        #**compute_metrics(merged.query("ner_label == {}".format(list(vocs["ner_label"]).index(c["pathologie"]))), prefix=prefix+"pathologie_"),
    }
    for name, query in queries.items():
        metrics.update(compute_metrics(merged.query(query), prefix=prefix+name+"_"))
    return metrics

In [6]:
# To debug the training, we can just comment the "def run_epoch()" and execute the function body manually without changing anything to it
def extract_mentions(batcher, all_nets, max_depth=10):
    """
    Parameters
    ----------
    batcher: Batcher 
        The batcher containing the text from which we want to extract the mentions (and maybe the gold mentions)
    ner_net: torch.nn.Module
    max_depth: int
        Max number of times we run the model per sample
        
    Returns
    -------
    Batcher
    """
    pred_batches = []
    n_mentions = 0
    ner_net = all_nets["ner_net"]
    tag_embeddings = all_nets["tag_embeddings"]
    with evaluating(all_nets):
        with torch.no_grad():
            for batch_i, batch in enumerate(batcher['sentence'].dataloader(batch_size=batch_size, shuffle=False, sparse_sort_on="token_mask", device=tg.device)):

                tag_embeds = torch.zeros(*batch["sentence", "token"].shape[:2], tag_dim, device=tg.device)
                current_sentences_idx = torch.arange(len(batch), device=tg.device)
                mask = batch["token_mask"]
                tokens = batch["token"]

                for i in range(max_depth):
                    # Run the model argmax here
                    ner_res = ner_net(
                        tokens = tokens,
                        mask = mask,
                        tag_embeds = tag_embeds,
                        return_embeddings=True
                    )

                    # Run the linear CRF Viterbi algorithm to compute the most likely sequence
                    pred_tags = ner_net.crf.decode(ner_res["scores"], mask)
                    spans = ner_net.crf.tags_to_spans(pred_tags, mask)

                    # Save predicted mentions
                    pred_batch = Batcher({
                        "mention": {
                            "mention_id": torch.arange(n_mentions, n_mentions+len(spans["span_doc_id"]), device=tg.device),
                            "begin": spans["span_begin"],
                            "end": spans["span_end"],
                            "ner_label": spans["span_label"],
                            "@sentence_id": current_sentences_idx[spans["span_doc_id"]],
                            "depth": torch.full_like(spans["span_begin"], fill_value=i),
                        },
                        "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                        "doc": dict(batch["doc"])}, 
                        check=False).sparsify()
                    pred_batches.append(pred_batch)
                    n_mentions += len(spans["span_doc_id"])

                    non_empty_sentences = torch.unique(spans["span_doc_id"])

                    if len(non_empty_sentences) == 0:
                        break

                    # Convert the predicted spans to tags using the same encoding scheme as the one used to decode predicted tags
                    # (We could use a different one: BIODecoder/BIOULDecoder.spans_to_tags is a static function)
                    feature_tags = ner_net.crf.spans_to_tags(
                        torch.arange(len(spans["span_begin"]), device=spans["span_begin"].device),
                        spans["span_begin"], 
                        spans["span_end"],
                        spans["span_label"], 
                        n_tokens=batch["sentence", "token"].shape[1],
                        n_samples=len(spans["span_begin"]),
                    )
                    tag_mention, tag_positions = feature_tags.nonzero(as_tuple=True)
                    tag_sentence = spans["span_doc_id"][tag_mention]
                    tag_values = feature_tags[tag_mention, tag_positions]

                    tag_embeds = tag_embeds.view(-1, tag_dim).index_add_(
                        dim=0,
                        index=tag_sentence * batch["sentence", "token"].shape[1] + tag_positions, 
                        source=tag_embeddings.weight[tag_values-1]).view(len(current_sentences_idx), batch["sentence", "token"].shape[1], tag_dim)[non_empty_sentences]

                    # Compute the tokens label tag embeddings of the observed (maybe overlapping) mentions
                    tokens = tokens[non_empty_sentences]
                    mask = mask[non_empty_sentences]
                    current_sentences_idx = current_sentences_idx[non_empty_sentences]
    return Batcher.concat(pred_batches)

In [7]:
from collections import defaultdict

# Define the training metrics
metrics_info = defaultdict(lambda: False)
flt_format = (5, "{:.4f}".format)
metrics_info.update({
    "train_loss": {"goal": 0, "format": flt_format},
    "train_ner_loss": {"goal": 0, "format": flt_format},
    #"train_recall": {"goal": 1, "format": flt_format, "name": "train_rec"},
    #"train_precision": {"goal": 1, "format": flt_format, "name": "train_prec"},
    "train_f1": {"goal": 1, "format": flt_format, "name": "train_f1"},
    
    "val_loss": {"goal": 0, "format": flt_format},
    "val_ner_loss": {"goal": 0, "format": flt_format},
    "val_label_loss": {"goal": 0, "format": flt_format},
    
    "val_f1": {"goal": 1, "format": flt_format, "name": "val_f1"},
    "val_3.1_f1": {"goal": 1, "format": flt_format, "name": "val_3.1_f1"},
    "val_3.2_f1": {"goal": 1, "format": flt_format, "name": "val_3.2_f1"},
    "val_macro_f1": {"goal": 1, "format": flt_format, "name": "val_macro_f1"},
    "val_sosy_f1": {"goal": 1, "format": flt_format, "name": "val_sosy_f1"},
    "val_pathologie_f1": {"goal": 1, "format": flt_format, "name": "val_patho_f1"},
    
    "duration": {"format": flt_format, "name": "   dur(s)"},
    "rescale": {"format": flt_format},
    "n_depth": {"format": flt_format},
    "n_matched": {"format": flt_format},
    "n_targets": {"format": flt_format},
    "n_observed": {"format": flt_format},
    "total_score_sum": {"format": flt_format},
    "lr": {"format": (5, "{:.2e}".format)},
})

In [8]:
def make_batcher(docs, sentences, zones, mentions, conflicts, tokens):
    """
    Parameters:
    ----------
    docs: pd.DataFrame
    sentences: pd.DataFrame
    zones: pd.DataFrame
    mentions: pd.DataFrame
    conflicts: pd.DataFrame
    tokens: pd.DataFrame
    
    Returns
    -------
    Batcher
    """
    docs = docs.copy()
    sentences = sentences.copy()
    zones = zones.copy()
    mentions = mentions.copy()
    conflicts = conflicts.copy()
    tokens = tokens.copy()
    
    [tokens["token_id"]], unique_token_id = factorize_rows([tokens["token_id"]])
    [mentions["mention_id"], conflicts["mention_id"], conflicts["mention_id_other"]], unique_mention_ids = factorize_rows(
        [mentions[["doc_id", "sentence_id", "mention_id"]], conflicts[["doc_id", "sentence_id", "mention_id"]], conflicts[["doc_id", "sentence_id", "mention_id_other"]]])
    [zones["zone_id"], mentions["zone_id"]], unique_zone_ids = factorize_rows(
        [zones[["doc_id", "sentence_id", "zone_id"]], mentions[["doc_id", "sentence_id", "zone_id"]]])
    [sentences["sentence_id"], zones["sentence_id"], mentions["sentence_id"], tokens["sentence_id"],], unique_sentence_ids = factorize_rows(
        [sentences[["doc_id", "sentence_id"]], zones[["doc_id", "sentence_id"]], mentions[["doc_id", "sentence_id"]], tokens[["doc_id", "sentence_id"]]])
    [docs["doc_id"], sentences["doc_id"], zones["doc_id"], mentions["doc_id"], tokens["doc_id"]], unique_doc_ids = factorize_rows(
        [docs["doc_id"], sentences["doc_id"], zones["doc_id"], mentions["doc_id"], tokens["doc_id"]])
    
    batcher = Batcher({
        "mention": {
            "mention_id": mentions["mention_id"],
            "zone_id": mentions["zone_id"],
            "sentence_id": mentions["sentence_id"],
            "doc_id": mentions["doc_id"],
            "begin": mentions["begin"],
            "end": mentions["end"],
            "ner_label": mentions["ner_label"].cat.codes,
            "conflict_mention_id": df_to_csr(conflicts["mention_id"], conflicts["conflict_idx"], conflicts["mention_id_other"], n_rows=len(unique_mention_ids)),
            "conflict_mask": df_to_csr(conflicts["mention_id"], conflicts["conflict_idx"], n_rows=len(unique_mention_ids)),
        },
        "zone": {
            "zone_id": zones["zone_id"],
            "sentence_id": zones["sentence_id"],
            "doc_id": zones["doc_id"],
            "mention_id": df_to_csr(mentions["zone_id"], mentions["zone_mention_idx"], mentions["mention_id"], n_rows=len(unique_zone_ids)),
            "mention_mask": df_to_csr(mentions["zone_id"], mentions["zone_mention_idx"], n_rows=len(unique_zone_ids)),
        },
        "sentence": {
            "sentence_id": sentences["sentence_id"],
            "doc_id": sentences["doc_id"],
            "mention_id": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], mentions["mention_id"], n_rows=len(unique_sentence_ids)),
            "mention_mask": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], n_rows=len(unique_sentence_ids)),
            "token": df_to_csr(tokens["sentence_id"], tokens["token_idx"], tokens["token"].cat.codes, n_rows=len(unique_sentence_ids)),
            "token_mask": df_to_csr(tokens["sentence_id"], tokens["token_idx"], n_rows=len(unique_sentence_ids)),
            "zone_id": df_to_csr(zones["sentence_id"], zones["zone_idx"], zones["zone_id"], n_rows=len(unique_sentence_ids)),
            "zone_mask": df_to_csr(zones["sentence_id"], zones["zone_idx"], n_rows=len(unique_sentence_ids)),
        },
        "doc": {
            "doc_id": np.arange(len(unique_doc_ids)),
            "sentence_id": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], sentences["sentence_id"], n_rows=len(unique_doc_ids)),
            "sentence_mask": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], n_rows=len(unique_doc_ids)),
            "split": docs["split"].cat.codes,
        }},
        masks={"sentence": {"token": "token_mask", "zone_id": "zone_mask", "mention_id": "mention_mask"}, 
               "mention": {"conflict_mention_id": "conflict_mask"},
               "zone": {"mention_id": "mention_mask"}, 
               "doc": {"sentence_id": "sentence_mask"}}
    )
    return (
        batcher, 
        dict(docs=docs, sentences=sentences, zones=zones, mentions=mentions, tokens=tokens),
        dict(token_id=unique_token_id, mention_id=unique_mention_ids, zone_id=unique_zone_ids, sentence_id=unique_sentence_ids, doc_id=unique_doc_ids)
    )

In [9]:
class NERNet(torch.nn.Module):
    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 metric='linear',
                 metric_fc_kwargs=None,
                 ):
        super().__init__()
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        assert token_dim is not None, "Provide token_dim or embeddings"
        assert self.embeddings is not None

        dim = (token_dim if n_tokens > 0 else 0)
        self.dropout = torch.nn.Dropout(dropout)
        if tag_scheme == "bio":
            self.crf = BIODecoder(n_labels)
        elif tag_scheme == "bioul":
            self.crf = BIOULDecoder(n_labels)
        else:
            raise Exception()
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        self.batch_norm = torch.nn.BatchNorm1d(dim)

        n_tags = self.crf.num_tags
        metric_fc_kwargs = metric_fc_kwargs if metric_fc_kwargs is not None else {}
        if metric == "linear":
            self.metric_fc = torch.nn.Linear(dim, n_tags)
        elif metric == "cosine":
            self.metric_fc = CosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        elif metric == "ema_cosine":
            self.metric_fc = EMACosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        else:
            raise Exception()
    
    def extended_embeddings(self, tokens, mask, **kwargs):
        # Default case here, size <= 512
        # Small ugly check to see if self.embeddings is Bert-like, then we need to pass a mask
        if hasattr(self.embeddings, 'encoder') or hasattr(self.embeddings, 'transformer'):
            return self.embeddings(tokens, mask, **kwargs)[0]
        else:
            return self.embeddings(tokens)

    def forward(self, tokens, mask, tag_embeds=None, return_embeddings=False):
        # Embed the tokens
        scores = None
        # shape: n_batch * sequence * 768
        embeds = self.extended_embeddings(tokens, mask, custom_embeds=tag_embeds)
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(self.linear(self.dropout(state)))# + state
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        scores = self.metric_fc(state)
        return {
            "scores": scores,
            "embeddings": embeds if return_embeddings else None,
        }

In [10]:
class LSTMNERNet(torch.nn.Module):
    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 metric='linear',
                 metric_fc_kwargs=None,
                 ):
        super().__init__()
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        assert token_dim is not None, "Provide token_dim or embeddings"
        assert self.embeddings is not None

        dim = (token_dim if n_tokens > 0 else 0)
        self.dropout = torch.nn.Dropout(dropout)
        if tag_scheme == "bio":
            self.crf = BIODecoder(n_labels)
        elif tag_scheme == "bioul":
            self.crf = BIOULDecoder(n_labels)
        else:
            raise Exception()
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        self.batch_norm = torch.nn.BatchNorm1d(dim)
        self.lstm = torch.nn.LSTM(hidden_dim, 
                                  hidden_dim, dropout=dropout, batch_first=True, num_layers=2, bidirectional=True)
            

        n_tags = self.crf.num_tags
        metric_fc_kwargs = metric_fc_kwargs if metric_fc_kwargs is not None else {}
        if metric == "linear":
            self.metric_fc = torch.nn.Linear(hidden_dim*2, n_tags)
        elif metric == "cosine":
            self.metric_fc = CosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        elif metric == "ema_cosine":
            self.metric_fc = EMACosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        else:
            raise Exception()
    
    def extended_embeddings(self, tokens, mask, **kwargs):
        # Default case here, size <= 512
        # Small ugly check to see if self.embeddings is Bert-like, then we need to pass a mask
        if hasattr(self.embeddings, 'encoder') or hasattr(self.embeddings, 'transformer'):
            return self.embeddings(tokens, mask, **kwargs)[0]
        else:
            return self.embeddings(tokens)

    def forward(self, tokens, mask, tag_embeds=None, return_embeddings=False):
        # Embed the tokens
        scores = None
        # shape: n_batch * sequence * 768
        embeds = self.extended_embeddings(tokens, mask, custom_embeds=tag_embeds)
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(self.linear(self.dropout(state)))# + state
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        
        lstm_state = self.lstm(self.dropout(state))[0]
        state = torch.relu(lstm_state)
        scores = self.metric_fc(state)
        return {
            "scores": scores,
            "embeddings": embeds if return_embeddings else None,
        }

In [11]:
#@cached
def preprocess(
    dataset,
    max_sentence_length,
    bert_name,
    ner_labels=None,
    unknown_labels="drop",
    vocabularies=None,
):
    """
    Parameters
    ----------
        dataset: Dataset
        max_sentence_length: int
            Max number of "words" as defined by the regex in regex_sentencize (so this is not the nb of wordpieces)
        bert_name: str
            bert path/name
        ner_labels: list of str 
            allowed ner labels (to be dropped or filtered)
        unknown_labels: str
            "drop" or "raise"
        vocabularies: dict[str; np.ndarray or list]
        
    Returns
    -------
    (pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, dict[str; np.ndarray or list])
        docs:      ('split', 'text', 'doc_id')
        sentences: ('split', 'doc_id', 'sentence_idx', 'begin', 'end', 'text', 'sentence_id')
        zones:     ('doc_id', 'sentence_id', 'zone_id', 'zone_idx')
        mentions:  ('ner_label', 'doc_id', 'sentence_id', 'mention_id', 'depth', 'zone_id', 'text', 'mention_idx', 'begin', 'end', 'zone_mention_idx')
        conflicts: ('doc_id', 'sentence_id', 'mention_id', 'mention_id_other', 'conflict_idx')
        tokens:    ('split', 'token', 'sentence_id', 'token_id', 'token_idx', 'begin', 'end', 'doc_id', 'sentence_idx')
        deltas:    ('doc_id', 'begin', 'end', 'delta')
        vocs: vocabularies to be reused later for encoding more data or decoding predictions
    """
    print("Dataset:", dataset)

    mentions = dataset["mentions"].rename({"label": "ner_label"}, axis=1)

    
    if ner_labels is not None:
        len_before = len(mentions)
        unknown_ner_labels = list(mentions[~mentions["ner_label"].isin(ner_labels)]["ner_label"].drop_duplicates())
        mentions = mentions[mentions["ner_label"].isin(ner_labels)]
        
        if len(unknown_ner_labels) and unknown_labels == "raise":
            raise Exception(f"Unkown labels in {len_before-len(mentions)} mentions: ", unknown_ner_labels)

    # Check that there is no mention overlap
    mentions = mentions.merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))

    
    print("Transform texts...", end=" ")
    transformed_docs, deltas = apply_substitutions(
        dataset["docs"], *zip(
            (r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
            (r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
            ("(?<=[a-zA-Z])(?=[0-9])", r" "),
            ("(?<=[0-9])(?=[A-Za-z])", r" "),
        ), apply_unidecode=True)
    transformed_docs = transformed_docs.astype({"text": str})
    transformed_mentions = apply_deltas(mentions, deltas, on=['doc_id'])
    print("done")
    
    print("Splitting into sentences...", end=" ")
    sentences = regex_sentencize(
        transformed_docs, 
        reg_split=r"((?:\s*\n\s*\n)+\s*|(?:(?<=[a-z0-9)]\n)|(?<=[a-z0-9)][ ](?:\.|\n))|(?<=[a-z0-9)][ ][ ](?:\.|\n)))\s*(?=[A-Z]))",
        min_sentence_length=0, max_sentence_length=max_sentence_length,
        # balance_parentheses=True, # default is True
    )
    
    [sentence_mentions], sentences, sentence_to_docs = partition_spans([transformed_mentions], sentences, new_id_name="sentence_id", overlap_policy=False)
#     n_sentences_per_mention = sentence_mentions.assign(count=1).groupby(["doc_id", "mention_id"], as_index=False).agg({"count": "sum", "text": "first", "sentence_id": "last"})
#     if n_sentences_per_mention["count"].max() > 1:
#         display(n_sentences_per_mention.query("count > 1"))
#         display(sentences[sentences["sentence_id"].isin(n_sentences_per_mention.query("count > 1")["sentence_id"])]["text"].tolist())
#         raise Exception("Some mentions could be mapped to more than 1 sentences ({})".format(n_sentences_per_mention["count"].max()))
    if sentence_to_docs is not None:
        sentence_mentions = sentence_mentions.merge(sentence_to_docs)
        
    sentence_mentions = sentence_mentions.assign(mention_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id"], {"mention_idx": lambda x: tuple(range(len(x)))})
    print("done")
    
    print("Tokenizing...", end=" ")
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
    sentences["text"] = sentences["text"].str.lower()
    tokens = huggingface_tokenize(sentences, tokenizer, doc_id_col="sentence_id")
    sentence_mentions = split_into_spans(sentence_mentions, tokens, pos_col="token_idx", overlap_policy=False)
    print("done")
    
    print("Processing zones (overlapping areas)...", end=" ")
    # Extract overlapping spans
    conflicts = (
        merge_with_spans(sentence_mentions, sentence_mentions, on=["doc_id", "sentence_id", ("begin", "end")], how="outer", suffixes=("", "_other"))
    )

    # ids1, and ids2 make the edges of the overlapping mentions of the same type (see the "ner_label")
    [ids1, ids2], unique_ids = factorize_rows(
        [conflicts[["doc_id", "sentence_id", "mention_id"]], 
         conflicts[["doc_id", "sentence_id", "mention_id_other"]]],
        sentence_mentions.eval("size=(end-begin)").sort_values("size")[["doc_id", "sentence_id", "mention_id"]]
    )
    g = nx.from_scipy_sparse_matrix(df_to_csr(ids1, ids2, n_rows=len(unique_ids), n_cols=len(unique_ids)))
    colored_nodes = np.asarray(list(nx.coloring.greedy_color(g, strategy=keep_order).items()))
    unique_ids['depth'] = colored_nodes[:, 1][colored_nodes[:, 0].argsort()]
    zone_indices, mention_indices = zip(*chain.from_iterable(zip(repeat(zone_idx), zone) for zone_idx, zone in enumerate(nx.connected_components(g))))
    conflicts = conflicts[["doc_id", "sentence_id", "mention_id", "mention_id_other"]].assign(conflict_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id", "mention_id"], {"conflict_idx": lambda x: tuple(range(len(x)))})

    zone_mentions = pd.DataFrame({
        **unique_ids.iloc[list(mention_indices)],
        "zone_id": zone_indices,
    }).merge(sentence_mentions, on=["doc_id", "sentence_id", "mention_id"]).sort_values(["doc_id", "sentence_id", "zone_id", "mention_id"])
    zone_mentions = zone_mentions.assign(zone_mention_idx=0).nlstruct.groupby_assign(['doc_id', 'sentence_id', 'zone_id'], {"zone_mention_idx": lambda vec: tuple(np.arange(len(vec)))})
    sentence_zones = zone_mentions[["doc_id", "sentence_id", "zone_id"]].drop_duplicates()
    sentence_zones = sentence_zones.assign(zone_idx=0).nlstruct.groupby_assign(['doc_id', 'sentence_id'], {"zone_idx": lambda vec: tuple(np.arange(len(vec)))})
    sentence_mentions = sentence_mentions.merge(zone_mentions.drop_duplicates(["doc_id", "sentence_id", "mention_id", "depth"]))
    print("done")

    print("Computing vocabularies...")
    [transformed_docs, sentences, sentence_zones, zone_mentions, tokens], vocs = normalize_vocabularies(
        [transformed_docs, sentences, sentence_zones, zone_mentions, tokens], 
        vocabularies={"split": ["train", "val", "test"]} if vocabularies is None else vocabularies,
        train_vocabularies={"source": False, "text": False} if vocabularies is None else False,
        verbose=True)
    print("done")
    return transformed_docs, sentences, sentence_zones, zone_mentions, conflicts, tokens, deltas, vocs

def keep_order(G, colors):
    """Returns a list of the nodes of ``G`` in ordered identically to their id in the graph
    ``G`` is a NetworkX graph. ``colors`` is ignored.
    This is to assign a depth using the nx.coloring.greedy_color function
    """
    return sorted(list(G))

In [12]:
bert_name = "camembert-base"
dataset = load_from_brat('/home/tannier/data/resources/daloux/brat_files')#load_genia_ner()

NEG_ONLY = False

if NEG_ONLY:
    neg_doc_ids = dataset['mentions']['doc_id'].unique()
    neg_docs = dataset['docs'][dataset['docs']['doc_id'].isin(neg_doc_ids)]
    neg_mentions = dataset['mentions'][dataset['mentions']['doc_id'].isin(neg_doc_ids)]
    neg_fragments = dataset['fragments'][dataset['fragments']['doc_id'].isin(neg_doc_ids)]
    
    dataset = Dataset(
        docs=neg_docs,
        mentions=neg_mentions,
        fragments=neg_fragments,
        attributes=dataset['attributes'],
        relations=dataset['relations'],
        comments=dataset['comments'],
    )

docs, sentences, zones, mentions, conflicts, tokens, deltas, vocs = preprocess(
    dataset=dataset,
    max_sentence_length=120,
    bert_name=bert_name,
    ner_labels= ['NEG'],
    unknown_labels="drop",
)
batcher, encoded, ids = make_batcher(docs, sentences, zones, mentions, conflicts, tokens)

Dataset: Dataset(
  (docs):       3790 * ('doc_id', 'text', 'split')
  (mentions):    926 * ('doc_id', 'mention_id', 'label', 'text')
  (fragments):   926 * ('doc_id', 'mention_id', 'fragment_id', 'begin', 'end')
  (attributes):    0 * ('doc_id', 'mention_id', 'attribute_id', 'label', 'value')
  (relations):     0 * ('doc_id', 'relation_id', 'relation_label', 'from_mention_id', 'to_mention_id')
  (comments):      0 * ('doc_id', 'comment_id', 'mention_id', 'comment')
)
Transform texts... done
Splitting into sentences... done
Tokenizing... 



done
Processing zones (overlapping areas)... done
Computing vocabularies...
Will train vocabulary for ner_label
Will train vocabulary for token
Discovered existing vocabulary (32005 entities) for token
Normalized text, with given vocabulary and no unk
Normalized text, with given vocabulary and no unk
Normalized text, with given vocabulary and no unk
Normalized split, with given vocabulary and no unk
Normalized split, with given vocabulary and no unk
Normalized split, with given vocabulary and no unk
done


In [13]:
#all_test_doc_ids = []
#sims = {}
#for i in range(200):
seed_all(1234567+137)

train_batcher = batcher['doc'][batcher['doc']['split']==0]['sentence']
test_batcher = batcher['doc'][batcher['doc']['split']==2]['sentence']

splits = np.zeros(len(train_batcher['doc']), dtype=int)

val_perc = 0.1
splits[np.random.choice(np.arange(len(splits)), size=int(val_perc * len(splits)))] = 1

val_batcher = batcher['sentence'][splits == 1]

# train_val_split = np.random.permutation(len(train_batcher))
# test_batcher = train_batcher[train_val_split[:int(0.1*len(train_val_split))]]['sentence']
# train_batcher = train_batcher[train_val_split[int(0.1*len(train_val_split)):]]['sentence']
sim = ((np.bincount(val_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(val_batcher['mention']) -
np.bincount(train_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(train_batcher['mention']))**2).sum()
print("Similarity (L2 dist) between train and val frequencies:", sim)
print("Frequencies")
#all_test_doc_ids.append((test_doc_ids, sim))
display(pd.DataFrame([
    {"index": "train", **dict(zip(vocs["ner_label"], np.bincount(train_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(train_batcher['mention'])))},
    {"index": "val", **dict(zip(vocs["ner_label"], np.bincount(val_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(val_batcher['mention'])))},
]))

Similarity (L2 dist) between train and val frequencies: 0.0
Frequencies


Unnamed: 0,index,NEG
0,train,1.0
1,val,1.0


In [14]:
# !pip install /home/yoann/these/DEFT/nlstruct/

In [15]:
import traceback
from tqdm import tqdm

from custom_bert import CustomBertModel
from transformers import AdamW, BertModel

from tqdm import tqdm
from scipy.sparse import csr_matrix
from logic_crf import CRF, ConstraintFactor, HintFactor, Indexer

from nlstruct.environment import get_cache
from nlstruct.utils import evaluating, torch_global as tg, freeze
from nlstruct.scoring import compute_metrics, merge_pred_and_gold
from nlstruct.train import make_optimizer_and_schedules, run_optimization, seed_all
from nlstruct.train.schedule import ScaleOnPlateauSchedule, LinearSchedule, ConstantSchedule
    
device = torch.device('cuda:1')
tg.set_device(device)
all_preds = []
histories = []

# To release gpu memory before allocating new parameters for a new model
# A better idea would be to run xp in a function, so that all variables are released when exiting the fn
# but this way we can debug after this cell if something goes wrong
if "all_nets" in globals(): del all_nets
if "optim" in globals(): del optim, 
if "schedules" in globals(): del schedules
if "final_schedule" in globals(): del final_schedule
if "state" in globals(): del state
    
# Hyperparameter search
for layer, hidden_dim, scheme, seed, lr, bert_lr, n_schedules, dropout in [
    (3, 1024 if "large" in bert_name else 768, "bioul", 12,  1e-2, 4e-5, 1, 0.1),
]:
    print(layer, hidden_dim, scheme, seed, lr, bert_lr, n_schedules, dropout)
    #seed = 123456
    seed_all(seed) # /!\ Super important to enable reproducibility

    tag_dim = 1024 if "large" in bert_name else 768#768
    max_grad_norm = 5.
    #lr = 1e-3
    #bert_lr = 6e-5
    tags_lr = bert_lr
    bert_weight_decay = 0.0000
    batch_size = 128
    random_perm=True
    observed_zone_sizes=None
    n_per_zone = "uniform"
    n_freeze = layer + 2 #4
    custom_embeds_layer_index = 19 if "large" in bert_name else 11  #layer#2
    #hidden_dim = 256
    bert_dropout = 0.1
    top_dropout = dropout

    ner_net = NERNet(
            n_tokens=len(vocs["token"]),
            token_dim=1024 if "large" in bert_name else 768,#768,
            n_labels=len(vocs["ner_label"]),
            embeddings=CustomBertModel.from_pretrained(bert_name, custom_embeds_layer_index=custom_embeds_layer_index),

            dropout=top_dropout,
            hidden_dim=hidden_dim,
            tag_scheme=scheme,
            metric='linear') # cosine might be better but looks less stable, oddly,
    all_nets = torch.nn.ModuleDict({
        "ner_net": ner_net,
        "tag_embeddings": torch.nn.Embedding(ner_net.crf.num_tags - 1, tag_dim),
    }).to(device=tg.device)
    del ner_net

    for module in all_nets["ner_net"].embeddings.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = bert_dropout
    all_nets.train()

    # Define the optimizer, maybe multiple learning rate / schedules per parameters groups
    optim, schedules = make_optimizer_and_schedules(all_nets, AdamW, {
        "lr": [
                               (lr,    bert_lr,    bert_lr,    tags_lr),
            (ConstantSchedule, (lr,    bert_lr,    bert_lr,    tags_lr),    15),
            (ConstantSchedule, (lr/4,  bert_lr/4,  bert_lr/4,  tags_lr/4),  15),
            (ConstantSchedule, (lr/16, bert_lr/16, bert_lr/16, tags_lr/16), 10),
            (ConstantSchedule, (lr/64, bert_lr/64, bert_lr/64, tags_lr/64), 10),
        ][:n_schedules+1],
    }, [
        "(?!ner_net\.embeddings\.|tag_embeddings\.).*",
        "ner_net\.embeddings\..*(bias|LayerNorm\.weight)",
        "ner_net\.embeddings\..*(?!bias|LayerNorm\.weight)",
        "tag_embeddings\..*",
    ], num_iter_per_epoch=(len(train_batcher) + 1) / batch_size)
    final_schedule = ScaleOnPlateauSchedule('lr', optim, patience=4, factor=0.25, verbose=True, mode='max')

    # Freeze some bert layers 
    # - n_freeze = 0 to freeze nothing
    # - n_freeze = 1 to freeze word embeddings / position embeddings / ...
    # - n_freeze = 2..13 to freeze the first, second ... 12th layer of bert
    for name, param in all_nets.named_parameters():
        match = re.search("\.(\d+)\.", name)
        if match and int(match.group(1)) < n_freeze - 1:
            freeze([param])
    if n_freeze > 0:
        if hasattr(all_nets['ner_net'].embeddings, 'embeddings'):
            freeze(all_nets['ner_net'].embeddings.embeddings)
        else:
            freeze(all_nets['ner_net'].embeddings)

    with_tqdm = True
    state = {"all_nets": all_nets, "optim": optim, "schedules": schedules, "final_schedule": final_schedule}  # all we need to restart the training from a given epoch

    def run_epoch():
        pred_batches = []
        gold_batches = []

        total_train_ner_loss = 0
        total_train_acc = 0
        total_train_ner_size = 0

        n_mentions = len(train_batcher["mention"])
        n_matched_mentions = 0
        n_target_mentions = 0
        n_observed_mentions = 0

        with tqdm(train_batcher['sentence'].dataloader(batch_size=batch_size, shuffle=True, sparse_sort_on="token_mask", device=device), disable=not with_tqdm) as bar:
            for batch_i, batch in enumerate(bar):
                optim.zero_grad()

                # Shuffle and split mentions in each zone between observed and target
                target_mentions, observed_mentions, zone_target_mentions, target_mask = split_zone_mentions(
                    batch,
                    random_perm=random_perm,
                    observed_zone_sizes=observed_zone_sizes,
                )
                n_target_mentions += len(target_mentions)
                n_observed_mentions += len(observed_mentions)

                # Compute the tokens label tag embeddings of the observed (maybe overlapping) mentions
                feature_tags = all_nets["ner_net"].crf.spans_to_tags(
                    torch.arange(len(observed_mentions), device=observed_mentions.device),
                    batch["mention", "begin"][observed_mentions], 
                    batch["mention", "end"][observed_mentions], 
                    batch["mention", "ner_label"][observed_mentions], 
                    n_tokens=batch["sentence", "token"].shape[1],
                    n_samples=len(observed_mentions),
                )
                tag_mention, tag_positions = feature_tags.nonzero(as_tuple=True)
                tag_sentence = batch["zone", "@sentence_id"][batch["mention", "@zone_id"]][observed_mentions][tag_mention]
                tag_values = feature_tags[tag_mention, tag_positions]
                tag_embeds = torch.zeros(*batch["sentence", "token"].shape[:2], tag_dim, device=tg.device).view(-1, tag_dim).index_add_(
                    dim=0,
                    index=tag_sentence * batch["sentence", "token"].shape[1] + tag_positions, 
                    source=all_nets["tag_embeddings"].weight[tag_values-1]).view(*batch["sentence", "token"].shape[:2], tag_dim)

                ##################################
                #       RUN THE NER MODEL        #
                ##################################
                # Run the model argmax here, we compute tag scores and embeddings
                mask = batch["token_mask"]
                ner_res = all_nets["ner_net"](
                    tokens = batch["token"],
                    mask = mask,
                    tag_embeds = tag_embeds,
                    return_embeddings=True,
                )
                scores = ner_res["scores"]
                embeds = ner_res["embeddings"]

                # Run the linear CRF Viterbi algorithm to compute the most likely sequence
                spans = all_nets["ner_net"].crf.tags_to_spans(all_nets["ner_net"].crf.decode(scores, mask), mask)

                # Save predicted mentions
                pred_batch = Batcher({
                    "mention": {
                        "mention_id": torch.arange(n_mentions, n_mentions+len(spans["span_doc_id"]), device=device),
                        "begin": spans["span_begin"],
                        "end": spans["span_end"],
                        "ner_label": spans["span_label"],
                        "@sentence_id": spans["span_doc_id"],
                    },
                    "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                    "doc": dict(batch["doc"])}, 
                    check=False)
                pred_batches.append(pred_batch)
                n_mentions += len(spans["span_doc_id"])

                ##################################
                #      NER LOSS COMPUTATION      #
                ##################################
                matched_mentions = select_closest_non_overlapping_gold_mentions(
                    gold_ids=target_mentions,
                    gold_sentence_ids=batch["zone", "@sentence_id"][batch["mention", "@zone_id"]][target_mentions],
                    gold_begins=batch["mention", "begin"][target_mentions],
                    gold_ends=batch["mention", "end"][target_mentions],

                    pred_sentence_ids=spans["span_doc_id"],
                    pred_begins=spans["span_begin"],
                    pred_ends=spans["span_end"],

                    zone_mention_id=batch["zone", "@mention_id"],
                    zone_mask=batch["zone", "mention_mask"],

                    gold_conflicts=batch["mention", "@conflict_mention_id"],
                    gold_conflicts_mask=batch["mention", "conflict_mask"],
                )
                n_matched_mentions += len(matched_mentions)
                gold_batches.append(batch["mention", matched_mentions].sparsify())

                # Compute the tokens label tag of the selected non-overlapping gold mentions to infer from the model
                target_tags = all_nets["ner_net"].crf.spans_to_tags(
                    batch["zone", "@sentence_id"][batch["mention", "@zone_id"][matched_mentions]],
                    batch["mention", "begin"][matched_mentions], 
                    batch["mention", "end"][matched_mentions], 
                    batch["mention", "ner_label"][matched_mentions], 
                    n_tokens=batch["sentence", "token"].shape[1],
                    n_samples=batch["sentence", "token"].shape[0],
                )
                # Run the linear CRF forward algorithm on the tokens to compute the loglikelihood of the targets
                ner_loss = -all_nets["ner_net"].crf(scores, mask, target_tags, reduction="mean")
                total_train_ner_loss += float(ner_loss) * len(batch["sentence"])
                total_train_ner_size += len(batch["sentence"])

                loss = ner_loss

                # Perform optimization step
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_nets.parameters(), max_grad_norm)
                optim.step()
                for schedule_name, schedule in schedules.items():
                    schedule.step()

        # Compute precision, recall and f1 on train set
        ner_pred = Batcher.concat(pred_batches)
        ner_gold = Batcher.concat(gold_batches)

        train_metrics    = compute_scores(ner_pred, ner_gold, prefix='train_')
        val_metrics     = compute_scores(extract_mentions(val_batcher, all_nets=all_nets), val_batcher, prefix='val_',
            queries={
                "3.1": "ner_label in ['NEG']",
                #"3.2": "ner_label in ['anatomie', 'dose', 'examen', 'mode', 'moment', 'substance', 'traitement', 'valeur']",
            }
                                        )
        # final_schedule.step(val_f1, state["epoch"])

        return \
        {
            "train_ner_loss": total_train_ner_loss / max(total_train_ner_size, 1),
            **train_metrics,
            # **val_metrics,
            **val_metrics,
            "val_macro_f1": val_metrics["val_3.1_f1"], #(val_metrics["val_3.1_f1"] + val_metrics["val_3.2_f1"]) / 2.,
            "n_matched": n_matched_mentions,
            "lr": schedules['lr'].get_val()[0],
        }

    try:
        best, history = run_optimization(
            main_score = "val_f1", # do not earlystop based on validation
            metrics_info=metrics_info,
            max_epoch=200,
            patience=None,
            state=state, 
            cache_policy="all", # only store metrics, not checkpoints
            cache=get_cache("daloux", {"seed": seed, "train_batcher": train_batcher, "val_batcher": None, "random_perm": random_perm, "observed_zone_sizes": observed_zone_sizes, "batch_size": batch_size, "max_grad_norm": max_grad_norm, **state}, loader=torch.load, dumper=torch.save),  # where to store the model (main name + hashed parameters)
            epoch_fn=run_epoch,
            n_save_checkpoints=2,
#             exit_on_score=0.92,
        )
        # histories.append({"layer": layer, "hidden_dim": hidden_dim, "scheme": scheme, "seed": seed, "history": history})
    except Exception as e:
        
        # We catch any exception otherwise some variables (including torch parameters on the gpu) end up being stored globally in sys.last_value, leading to memory errors)
        traceback.print_exc()
        break.
    finally:
        pass
        #del optim, schedules, final_schedule, state

Available CUDA devices 8
Current device cuda:1
3 768 bioul 12 0.01 4e-05 1 0.1
before layer norm
Using cache /home/tannier/data/cache/daloux/2bb9f55995f7f9f3


  warn(f"Entry '{key}' in the state seems to be mutable but has no load_state_dict/state_dict methods. This could lead to unpredictable behaviors.")
100%|██████████| 49/49 [00:29<00:00,  1.66it/s]
  0%|          | 0/49 [00:00<?, ?it/s]



epoch | train_ner_loss | train_f1 | [31mval_f1[0m | val_3.1_f1 | val_macro_f1 | n_matched |       lr |    dur(s)
    1 |         [32m7.2745[0m |   [32m0.0040[0m | [32m0.0000[0m |     [32m0.0000[0m |       [32m0.0000[0m |  356.0000 | 1.00e-02 |   31.1353


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  2%|▏         | 1/49 [00:00<00:07,  6.59it/s]

    2 |         [32m2.9660[0m |   [32m0.0537[0m | [31m0.0000[0m |     [31m0.0000[0m |       [31m0.0000[0m |  372.0000 | 1.00e-02 |   18.2243


100%|██████████| 49/49 [00:21<00:00,  2.31it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    3 |         [32m2.4525[0m |   [32m0.0750[0m | [32m0.0870[0m |     [32m0.0870[0m |       [32m0.0870[0m |  372.0000 | 1.00e-02 |   22.5311


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    4 |         [32m1.7418[0m |   [32m0.1553[0m | [31m0.0000[0m |     [31m0.0000[0m |       [31m0.0000[0m |  338.0000 | 1.00e-02 |   18.3814


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    5 |         [32m1.3981[0m |   [32m0.2202[0m | [32m0.3932[0m |     [32m0.3932[0m |       [32m0.3932[0m |  393.0000 | 1.00e-02 |   18.9467


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    6 |         [32m1.1315[0m |   [32m0.2599[0m | [31m0.3299[0m |     [31m0.3299[0m |       [31m0.3299[0m |  387.0000 | 1.00e-02 |   18.9523


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    7 |         [32m0.9462[0m |   [32m0.3194[0m | [31m0.3444[0m |     [31m0.3444[0m |       [31m0.3444[0m |  381.0000 | 1.00e-02 |   18.8157


100%|██████████| 49/49 [00:19<00:00,  2.51it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    8 |         [32m0.7878[0m |   [32m0.3742[0m | [31m0.0000[0m |     [31m0.0000[0m |       [31m0.0000[0m |  350.0000 | 1.00e-02 |   20.6121


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

    9 |         [32m0.7805[0m |   [31m0.3683[0m | [31m0.3575[0m |     [31m0.3575[0m |       [31m0.3575[0m |  356.0000 | 1.00e-02 |   18.7342


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   10 |         [32m0.6293[0m |   [32m0.4261[0m | [32m0.4878[0m |     [32m0.4878[0m |       [32m0.4878[0m |  357.0000 | 1.00e-02 |   18.8671


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   11 |         [32m0.5590[0m |   [32m0.4467[0m | [31m0.3148[0m |     [31m0.3148[0m |       [31m0.3148[0m |  367.0000 | 1.00e-02 |   18.9123


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   12 |         [32m0.4646[0m |   [32m0.5233[0m | [32m0.5806[0m |     [32m0.5806[0m |       [32m0.5806[0m |  366.0000 | 1.00e-02 |   18.8204


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   13 |         [32m0.4132[0m |   [31m0.5166[0m | [32m0.7097[0m |     [32m0.7097[0m |       [32m0.7097[0m |  387.0000 | 1.00e-02 |   18.8515


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   14 |         [32m0.4065[0m |   [32m0.5422[0m | [31m0.4158[0m |     [31m0.4158[0m |       [31m0.4158[0m |  389.0000 | 1.00e-02 |   18.4722


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   15 |         [32m0.3730[0m |   [32m0.5479[0m | [31m0.6719[0m |     [31m0.6719[0m |       [31m0.6719[0m |  358.0000 | 1.00e-02 |   18.7271


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   16 |         [32m0.2995[0m |   [32m0.6304[0m | [31m0.5686[0m |     [31m0.5686[0m |       [31m0.5686[0m |  380.0000 | 1.00e-02 |   18.7151


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   17 |         [31m0.3161[0m |   [31m0.5978[0m | [31m0.6667[0m |     [31m0.6667[0m |       [31m0.6667[0m |  340.0000 | 1.00e-02 |   18.6326


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   18 |         [31m0.3320[0m |   [31m0.5481[0m | [31m0.6222[0m |     [31m0.6222[0m |       [31m0.6222[0m |  358.0000 | 1.00e-02 |   18.8332


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   19 |         [31m0.3165[0m |   [31m0.6156[0m | [32m0.7368[0m |     [32m0.7368[0m |       [32m0.7368[0m |  375.0000 | 1.00e-02 |   18.8652


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   20 |         [32m0.2650[0m |   [32m0.6716[0m | [31m0.5735[0m |     [31m0.5735[0m |       [31m0.5735[0m |  357.0000 | 1.00e-02 |   18.5396


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   21 |         [32m0.2509[0m |   [32m0.6779[0m | [32m0.7581[0m |     [32m0.7581[0m |       [32m0.7581[0m |  368.0000 | 1.00e-02 |   18.3900


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   22 |         [31m0.2736[0m |   [32m0.6953[0m | [31m0.1429[0m |     [31m0.1429[0m |       [31m0.1429[0m |  353.0000 | 1.00e-02 |   18.9333


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  2%|▏         | 1/49 [00:00<00:07,  6.75it/s]

   23 |         [31m0.2941[0m |   [31m0.6435[0m | [31m0.6885[0m |     [31m0.6885[0m |       [31m0.6885[0m |  373.0000 | 1.00e-02 |   18.9003


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   24 |         [32m0.1920[0m |   [32m0.7570[0m | [32m0.7840[0m |     [32m0.7840[0m |       [32m0.7840[0m |  387.0000 | 1.00e-02 |   18.5841


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   25 |         [32m0.1487[0m |   [31m0.7564[0m | [32m0.8430[0m |     [32m0.8430[0m |       [32m0.8430[0m |  361.0000 | 1.00e-02 |   18.9263


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   26 |         [31m0.1557[0m |   [32m0.7870[0m | [31m0.8387[0m |     [31m0.8387[0m |       [31m0.8387[0m |  363.0000 | 1.00e-02 |   18.5177


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   27 |         [32m0.1308[0m |   [32m0.7878[0m | [31m0.3596[0m |     [31m0.3596[0m |       [31m0.3596[0m |  368.0000 | 1.00e-02 |   18.9589


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   28 |         [31m0.1461[0m |   [32m0.8017[0m | [31m0.8095[0m |     [31m0.8095[0m |       [31m0.8095[0m |  356.0000 | 1.00e-02 |   18.6962


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   29 |         [31m0.1710[0m |   [31m0.7947[0m | [32m0.8480[0m |     [32m0.8480[0m |       [32m0.8480[0m |  384.0000 | 1.00e-02 |   18.6477


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   30 |         [32m0.1131[0m |   [32m0.8138[0m | [31m0.6456[0m |     [31m0.6456[0m |       [31m0.6456[0m |  370.0000 | 1.00e-02 |   18.8911


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   31 |         [31m0.1147[0m |   [32m0.8376[0m | [32m0.8527[0m |     [32m0.8527[0m |       [32m0.8527[0m |  357.0000 | 1.00e-02 |   18.4835


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   32 |         [31m0.1139[0m |   [32m0.8427[0m | [31m0.8197[0m |     [31m0.8197[0m |       [31m0.8197[0m |  350.0000 | 1.00e-02 |   18.5550


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   33 |         [32m0.0715[0m |   [32m0.9106[0m | [32m0.8769[0m |     [32m0.8769[0m |       [32m0.8769[0m |  369.0000 | 1.00e-02 |   18.7580


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   34 |         [32m0.0613[0m |   [32m0.9188[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  367.0000 | 1.00e-02 |   18.8538


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.75it/s]

   35 |         [32m0.0585[0m |   [32m0.9290[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  377.0000 | 1.00e-02 |   18.8187


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   36 |         [32m0.0517[0m |   [32m0.9306[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  370.0000 | 1.00e-02 |   18.7164


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   37 |         [32m0.0360[0m |   [32m0.9443[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  379.0000 | 1.00e-02 |   18.8729


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   38 |         [31m0.0419[0m |   [31m0.9313[0m | [32m0.9016[0m |     [32m0.9016[0m |       [32m0.9016[0m |  365.0000 | 1.00e-02 |   18.8376


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   39 |         [31m0.0560[0m |   [31m0.9225[0m | [31m0.8661[0m |     [31m0.8661[0m |       [31m0.8661[0m |  375.0000 | 1.00e-02 |   18.8542


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   40 |         [31m0.0423[0m |   [32m0.9573[0m | [31m0.8346[0m |     [31m0.8346[0m |       [31m0.8346[0m |  353.0000 | 1.00e-02 |   18.4941


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   41 |         [31m0.0817[0m |   [31m0.8920[0m | [31m0.8730[0m |     [31m0.8730[0m |       [31m0.8730[0m |  353.0000 | 1.00e-02 |   18.7901


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   42 |         [31m0.0423[0m |   [31m0.9420[0m | [31m0.8730[0m |     [31m0.8730[0m |       [31m0.8730[0m |  349.0000 | 1.00e-02 |   18.5799


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   43 |         [32m0.0315[0m |   [32m0.9760[0m | [31m0.8092[0m |     [31m0.8092[0m |       [31m0.8092[0m |  357.0000 | 1.00e-02 |   18.5697


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   44 |         [31m0.0554[0m |   [31m0.9348[0m | [31m0.8485[0m |     [31m0.8485[0m |       [31m0.8485[0m |  370.0000 | 1.00e-02 |   18.6661


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   45 |         [31m0.0604[0m |   [31m0.9167[0m | [32m0.9134[0m |     [32m0.9134[0m |       [32m0.9134[0m |  379.0000 | 1.00e-02 |   18.6541


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   46 |         [31m0.0642[0m |   [31m0.9252[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  371.0000 | 1.00e-02 |   18.7906


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   47 |         [31m0.0548[0m |   [31m0.9515[0m | [31m0.5909[0m |     [31m0.5909[0m |       [31m0.5909[0m |  364.0000 | 1.00e-02 |   19.1324


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   48 |         [31m0.0319[0m |   [31m0.9622[0m | [31m0.8571[0m |     [31m0.8571[0m |       [31m0.8571[0m |  370.0000 | 1.00e-02 |   18.6999


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   49 |         [31m0.0490[0m |   [31m0.9237[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  376.0000 | 1.00e-02 |   19.0001


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   50 |         [31m0.0365[0m |   [31m0.9436[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  381.0000 | 1.00e-02 |   18.8958


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  2%|▏         | 1/49 [00:00<00:07,  6.83it/s]

   51 |         [32m0.0290[0m |   [31m0.9699[0m | [31m0.6243[0m |     [31m0.6243[0m |       [31m0.6243[0m |  384.0000 | 1.00e-02 |   19.1070


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   52 |         [32m0.0173[0m |   [32m0.9825[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  371.0000 | 1.00e-02 |   18.7216


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   53 |         [31m0.0548[0m |   [31m0.9545[0m | [31m0.8960[0m |     [31m0.8960[0m |       [31m0.8960[0m |  373.0000 | 1.00e-02 |   18.6553


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   54 |         [31m0.0407[0m |   [31m0.9570[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  371.0000 | 1.00e-02 |   19.0017


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   55 |         [31m0.0484[0m |   [31m0.9687[0m | [31m0.8871[0m |     [31m0.8871[0m |       [31m0.8871[0m |  369.0000 | 1.00e-02 |   18.8085


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   56 |         [31m0.0309[0m |   [31m0.9733[0m | [31m0.8060[0m |     [31m0.8060[0m |       [31m0.8060[0m |  377.0000 | 1.00e-02 |   18.9744


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   57 |         [31m0.0416[0m |   [31m0.9571[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  351.0000 | 1.00e-02 |   18.8804


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   58 |         [31m0.0482[0m |   [31m0.9412[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  356.0000 | 1.00e-02 |   18.9181


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   59 |         [31m0.0658[0m |   [31m0.9188[0m | [31m0.1896[0m |     [31m0.1896[0m |       [31m0.1896[0m |  360.0000 | 1.00e-02 |   18.9419


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   60 |         [31m0.0689[0m |   [31m0.8981[0m | [31m0.7907[0m |     [31m0.7907[0m |       [31m0.7907[0m |  365.0000 | 1.00e-02 |   18.6347


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   61 |         [31m0.0592[0m |   [31m0.8895[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  362.0000 | 1.00e-02 |   18.6664


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   62 |         [31m0.0301[0m |   [31m0.9628[0m | [31m0.8333[0m |     [31m0.8333[0m |       [31m0.8333[0m |  362.0000 | 1.00e-02 |   19.0947


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   63 |         [31m0.0287[0m |   [31m0.9562[0m | [31m0.8154[0m |     [31m0.8154[0m |       [31m0.8154[0m |  391.0000 | 1.00e-02 |   18.8402


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   64 |         [31m0.0753[0m |   [31m0.9272[0m | [31m0.8125[0m |     [31m0.8125[0m |       [31m0.8125[0m |  353.0000 | 1.00e-02 |   19.0556


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   65 |         [31m0.0434[0m |   [31m0.9410[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  357.0000 | 1.00e-02 |   18.7215


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   66 |         [31m0.0209[0m |   [31m0.9792[0m | [31m0.8571[0m |     [31m0.8571[0m |       [31m0.8571[0m |  384.0000 | 1.00e-02 |   18.8217


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   67 |         [31m0.0533[0m |   [31m0.9571[0m | [31m0.8871[0m |     [31m0.8871[0m |       [31m0.8871[0m |  363.0000 | 1.00e-02 |   18.7186


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   68 |         [31m0.0236[0m |   [31m0.9634[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  357.0000 | 1.00e-02 |   18.9132


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   69 |         [32m0.0117[0m |   [31m0.9773[0m | [31m0.8217[0m |     [31m0.8217[0m |       [31m0.8217[0m |  353.0000 | 1.00e-02 |   19.0377


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   70 |         [32m0.0087[0m |   [32m0.9946[0m | [31m0.8346[0m |     [31m0.8346[0m |       [31m0.8346[0m |  371.0000 | 1.00e-02 |   18.7805


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   71 |         [31m0.0159[0m |   [31m0.9777[0m | [31m0.8615[0m |     [31m0.8615[0m |       [31m0.8615[0m |  360.0000 | 1.00e-02 |   18.9696


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   72 |         [31m0.0417[0m |   [31m0.9895[0m | [31m0.8209[0m |     [31m0.8209[0m |       [31m0.8209[0m |  381.0000 | 1.00e-02 |   19.0527


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   73 |         [31m0.0594[0m |   [31m0.9310[0m | [31m0.8504[0m |     [31m0.8504[0m |       [31m0.8504[0m |  373.0000 | 1.00e-02 |   18.7089


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   74 |         [31m0.0414[0m |   [31m0.9404[0m | [31m0.8500[0m |     [31m0.8500[0m |       [31m0.8500[0m |  364.0000 | 1.00e-02 |   18.8012


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   75 |         [31m0.0587[0m |   [31m0.9169[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  387.0000 | 1.00e-02 |   18.5591


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   76 |         [31m0.0166[0m |   [31m0.9872[0m | [31m0.8462[0m |     [31m0.8462[0m |       [31m0.8462[0m |  391.0000 | 1.00e-02 |   19.0764


100%|██████████| 49/49 [00:17<00:00,  2.75it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   77 |         [31m0.0126[0m |   [31m0.9872[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  354.0000 | 1.00e-02 |   19.1151


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   78 |         [31m0.0148[0m |   [31m0.9762[0m | [31m0.8618[0m |     [31m0.8618[0m |       [31m0.8618[0m |  378.0000 | 1.00e-02 |   19.0098


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   79 |         [31m0.0307[0m |   [31m0.9545[0m | [31m0.8780[0m |     [31m0.8780[0m |       [31m0.8780[0m |  373.0000 | 1.00e-02 |   18.8582


100%|██████████| 49/49 [00:17<00:00,  2.75it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   80 |         [31m0.0375[0m |   [31m0.9710[0m | [31m0.7857[0m |     [31m0.7857[0m |       [31m0.7857[0m |  365.0000 | 1.00e-02 |   19.3987


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   81 |         [31m0.0437[0m |   [31m0.9571[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  363.0000 | 1.00e-02 |   18.8439


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   82 |         [31m0.0674[0m |   [31m0.9418[0m | [31m0.8976[0m |     [31m0.8976[0m |       [31m0.8976[0m |  395.0000 | 1.00e-02 |   18.8852


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   83 |         [31m0.0643[0m |   [31m0.9229[0m | [31m0.7862[0m |     [31m0.7862[0m |       [31m0.7862[0m |  375.0000 | 1.00e-02 |   18.9392


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   84 |         [31m0.0551[0m |   [31m0.9525[0m | [31m0.6853[0m |     [31m0.6853[0m |       [31m0.6853[0m |  348.0000 | 1.00e-02 |   18.9536


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   85 |         [31m0.0188[0m |   [31m0.9798[0m | [31m0.8547[0m |     [31m0.8547[0m |       [31m0.8547[0m |  371.0000 | 1.00e-02 |   19.0869


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   86 |         [31m0.0471[0m |   [31m0.9781[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  368.0000 | 1.00e-02 |   18.8490


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   87 |         [31m0.0344[0m |   [31m0.9621[0m | [31m0.8640[0m |     [31m0.8640[0m |       [31m0.8640[0m |  396.0000 | 1.00e-02 |   18.6690


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   88 |         [31m0.0663[0m |   [31m0.9635[0m | [31m0.7172[0m |     [31m0.7172[0m |       [31m0.7172[0m |  370.0000 | 1.00e-02 |   18.9690


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   89 |         [31m0.0553[0m |   [31m0.9441[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  359.0000 | 1.00e-02 |   18.6929


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   90 |         [31m0.0928[0m |   [31m0.8902[0m | [31m0.7812[0m |     [31m0.7812[0m |       [31m0.7812[0m |  346.0000 | 1.00e-02 |   18.6747


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   91 |         [31m0.0164[0m |   [31m0.9702[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  352.0000 | 1.00e-02 |   18.9679


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   92 |         [31m0.0234[0m |   [31m0.9858[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  387.0000 | 1.00e-02 |   18.6910


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   93 |         [31m0.0111[0m |   [31m0.9852[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  374.0000 | 1.00e-02 |   18.8729


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   94 |         [32m0.0051[0m |   [32m0.9957[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  351.0000 | 1.00e-02 |   18.7816


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   95 |         [32m0.0047[0m |   [31m0.9918[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  363.0000 | 1.00e-02 |   18.6105


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   96 |         [31m0.0064[0m |   [31m0.9892[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  370.0000 | 1.00e-02 |   18.7354


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   97 |         [31m0.0285[0m |   [31m0.9842[0m | [31m0.8640[0m |     [31m0.8640[0m |       [31m0.8640[0m |  378.0000 | 1.00e-02 |   18.8558


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   98 |         [31m0.1022[0m |   [31m0.9266[0m | [31m0.8926[0m |     [31m0.8926[0m |       [31m0.8926[0m |  377.0000 | 1.00e-02 |   18.8987


100%|██████████| 49/49 [00:17<00:00,  2.77it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

   99 |         [31m0.0810[0m |   [31m0.9514[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  351.0000 | 1.00e-02 |   18.9600


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  100 |         [31m0.0301[0m |   [31m0.9621[0m | [31m0.8661[0m |     [31m0.8661[0m |       [31m0.8661[0m |  397.0000 | 1.00e-02 |   18.9663


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  101 |         [31m0.0174[0m |   [31m0.9604[0m | [31m0.8397[0m |     [31m0.8397[0m |       [31m0.8397[0m |  392.0000 | 1.00e-02 |   18.7711


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  102 |         [31m0.0098[0m |   [31m0.9906[0m | [31m0.8800[0m |     [31m0.8800[0m |       [31m0.8800[0m |  373.0000 | 1.00e-02 |   18.6919


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  103 |         [31m0.0124[0m |   [31m0.9879[0m | [31m0.8504[0m |     [31m0.8504[0m |       [31m0.8504[0m |  373.0000 | 1.00e-02 |   18.7732


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  104 |         [31m0.0105[0m |   [31m0.9861[0m | [31m0.8730[0m |     [31m0.8730[0m |       [31m0.8730[0m |  360.0000 | 1.00e-02 |   18.5930


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  105 |         [31m0.0458[0m |   [31m0.9896[0m | [31m0.7051[0m |     [31m0.7051[0m |       [31m0.7051[0m |  384.0000 | 1.00e-02 |   19.1938


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  2%|▏         | 1/49 [00:00<00:07,  6.76it/s]

  106 |         [31m0.0170[0m |   [31m0.9815[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  379.0000 | 1.00e-02 |   18.8718


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  107 |         [31m0.0531[0m |   [31m0.9608[0m | [31m0.7846[0m |     [31m0.7846[0m |       [31m0.7846[0m |  369.0000 | 1.00e-02 |   18.8167


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  108 |         [31m0.0557[0m |   [31m0.9456[0m | [31m0.8397[0m |     [31m0.8397[0m |       [31m0.8397[0m |  359.0000 | 1.00e-02 |   18.9707


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  109 |         [31m0.0301[0m |   [31m0.9763[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  381.0000 | 1.00e-02 |   18.8239


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  110 |         [31m0.0499[0m |   [31m0.9328[0m | [31m0.8244[0m |     [31m0.8244[0m |       [31m0.8244[0m |  368.0000 | 1.00e-02 |   18.6840


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  111 |         [31m0.0291[0m |   [31m0.9749[0m | [31m0.8852[0m |     [31m0.8852[0m |       [31m0.8852[0m |  398.0000 | 1.00e-02 |   18.4145


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  112 |         [31m0.0237[0m |   [31m0.9761[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  356.0000 | 1.00e-02 |   18.4058


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  113 |         [31m0.0619[0m |   [31m0.9511[0m | [31m0.8780[0m |     [31m0.8780[0m |       [31m0.8780[0m |  370.0000 | 1.00e-02 |   18.9053


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  114 |         [31m0.0089[0m |   [31m0.9872[0m | [31m0.8960[0m |     [31m0.8960[0m |       [31m0.8960[0m |  390.0000 | 1.00e-02 |   18.7530


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.90it/s]

  115 |         [31m0.0137[0m |   [31m0.9841[0m | [31m0.8889[0m |     [31m0.8889[0m |       [31m0.8889[0m |  378.0000 | 1.00e-02 |   18.9560


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  116 |         [31m0.0209[0m |   [31m0.9774[0m | [31m0.8710[0m |     [31m0.8710[0m |       [31m0.8710[0m |  376.0000 | 1.00e-02 |   18.7183


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  117 |         [31m0.0115[0m |   [31m0.9759[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  354.0000 | 1.00e-02 |   18.7429


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  118 |         [31m0.0087[0m |   [31m0.9918[0m | [31m0.8800[0m |     [31m0.8800[0m |       [31m0.8800[0m |  364.0000 | 1.00e-02 |   18.5020


100%|██████████| 49/49 [00:17<00:00,  2.79it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  119 |         [31m0.0262[0m |   [31m0.9906[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  371.0000 | 1.00e-02 |   18.9927


100%|██████████| 49/49 [00:17<00:00,  2.78it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  120 |         [31m0.0286[0m |   [31m0.9592[0m | [31m0.8438[0m |     [31m0.8438[0m |       [31m0.8438[0m |  355.0000 | 1.00e-02 |   19.1269


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  121 |         [31m0.0418[0m |   [31m0.9497[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  356.0000 | 1.00e-02 |   18.6379


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  122 |         [31m0.0284[0m |   [31m0.9668[0m | [32m0.9180[0m |     [32m0.9180[0m |       [32m0.9180[0m |  346.0000 | 1.00e-02 |   18.6728


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  123 |         [31m0.0671[0m |   [31m0.9533[0m | [31m0.8271[0m |     [31m0.8271[0m |       [31m0.8271[0m |  353.0000 | 1.00e-02 |   18.5872


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  124 |         [31m0.0507[0m |   [31m0.9565[0m | [31m0.8889[0m |     [31m0.8889[0m |       [31m0.8889[0m |  390.0000 | 1.00e-02 |   18.4440


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  125 |         [31m0.1187[0m |   [31m0.9301[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  358.0000 | 1.00e-02 |   19.0462


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  126 |         [31m0.0106[0m |   [31m0.9720[0m | [31m0.8889[0m |     [31m0.8889[0m |       [31m0.8889[0m |  376.0000 | 1.00e-02 |   18.8803


100%|██████████| 49/49 [00:16<00:00,  2.93it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  127 |         [31m0.0251[0m |   [31m0.9698[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  349.0000 | 1.00e-02 |   18.0427


100%|██████████| 49/49 [00:16<00:00,  2.97it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  128 |         [31m0.0490[0m |   [31m0.9679[0m | [31m0.8548[0m |     [31m0.8548[0m |       [31m0.8548[0m |  358.0000 | 1.00e-02 |   17.7736


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  129 |         [31m0.0490[0m |   [31m0.9472[0m | [31m0.8852[0m |     [31m0.8852[0m |       [31m0.8852[0m |  379.0000 | 1.00e-02 |   18.4865


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  130 |         [31m0.0329[0m |   [31m0.9739[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  366.0000 | 1.00e-02 |   18.7205


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  131 |         [31m0.0368[0m |   [31m0.9808[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  367.0000 | 1.00e-02 |   18.8912


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  132 |         [31m0.0154[0m |   [31m0.9821[0m | [31m0.8346[0m |     [31m0.8346[0m |       [31m0.8346[0m |  364.0000 | 1.00e-02 |   18.5588


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  133 |         [31m0.0404[0m |   [31m0.9782[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  343.0000 | 1.00e-02 |   18.7289


100%|██████████| 49/49 [00:16<00:00,  2.95it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  134 |         [31m0.0780[0m |   [31m0.9526[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  375.0000 | 1.00e-02 |   17.9138


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  135 |         [31m0.0728[0m |   [31m0.9487[0m | [31m0.8029[0m |     [31m0.8029[0m |       [31m0.8029[0m |  361.0000 | 1.00e-02 |   18.2606


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  136 |         [31m0.0258[0m |   [31m0.9803[0m | [31m0.8640[0m |     [31m0.8640[0m |       [31m0.8640[0m |  357.0000 | 1.00e-02 |   18.4257


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  137 |         [31m0.0288[0m |   [31m0.9589[0m | [31m0.8871[0m |     [31m0.8871[0m |       [31m0.8871[0m |  368.0000 | 1.00e-02 |   18.4416


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  138 |         [31m0.0146[0m |   [31m0.9736[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  361.0000 | 1.00e-02 |   18.3936


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  139 |         [31m0.0222[0m |   [31m0.9889[0m | [31m0.9048[0m |     [31m0.9048[0m |       [31m0.9048[0m |  360.0000 | 1.00e-02 |   18.5173


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  140 |         [31m0.0418[0m |   [31m0.9775[0m | [31m0.8462[0m |     [31m0.8462[0m |       [31m0.8462[0m |  381.0000 | 1.00e-02 |   18.2975


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  141 |         [31m0.1261[0m |   [31m0.8988[0m | [31m0.8235[0m |     [31m0.8235[0m |       [31m0.8235[0m |  378.0000 | 1.00e-02 |   18.7627


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  142 |         [31m0.0406[0m |   [31m0.9615[0m | [31m0.9120[0m |     [31m0.9120[0m |       [31m0.9120[0m |  365.0000 | 1.00e-02 |   18.5984


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  143 |         [31m0.0123[0m |   [31m0.9810[0m | [31m0.8000[0m |     [31m0.8000[0m |       [31m0.8000[0m |  369.0000 | 1.00e-02 |   18.5576


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  144 |         [31m0.0099[0m |   [31m0.9823[0m | [31m0.9062[0m |     [31m0.9062[0m |       [31m0.9062[0m |  340.0000 | 1.00e-02 |   18.6661


100%|██████████| 49/49 [00:16<00:00,  2.89it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  145 |         [31m0.0262[0m |   [31m0.9811[0m | [32m0.9206[0m |     [32m0.9206[0m |       [32m0.9206[0m |  343.0000 | 1.00e-02 |   18.2402


100%|██████████| 49/49 [00:16<00:00,  2.96it/s]
  2%|▏         | 1/49 [00:00<00:09,  5.31it/s]

  146 |         [31m0.0095[0m |   [31m0.9821[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  391.0000 | 1.00e-02 |   17.8758


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.65it/s]

  147 |         [31m0.0172[0m |   [31m0.9806[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  360.0000 | 1.00e-02 |   18.3478


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  148 |         [31m0.1315[0m |   [31m0.9081[0m | [31m0.8906[0m |     [31m0.8906[0m |       [31m0.8906[0m |  363.0000 | 1.00e-02 |   18.2006


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  149 |         [31m0.0381[0m |   [31m0.9615[0m | [31m0.3684[0m |     [31m0.3684[0m |       [31m0.3684[0m |  339.0000 | 1.00e-02 |   19.8607


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  2%|▏         | 1/49 [00:00<00:09,  5.28it/s]

  150 |         [31m0.4968[0m |   [31m0.8431[0m | [31m0.8092[0m |     [31m0.8092[0m |       [31m0.8092[0m |  358.0000 | 1.00e-02 |   18.7966


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  151 |         [31m0.0814[0m |   [31m0.8991[0m | [31m0.8480[0m |     [31m0.8480[0m |       [31m0.8480[0m |  370.0000 | 1.00e-02 |   18.4824


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  152 |         [31m0.0647[0m |   [31m0.9664[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  375.0000 | 1.00e-02 |   18.2172


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  153 |         [31m0.0236[0m |   [31m0.9568[0m | [31m0.8387[0m |     [31m0.8387[0m |       [31m0.8387[0m |  360.0000 | 1.00e-02 |   18.5158


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  154 |         [31m0.0175[0m |   [31m0.9809[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  366.0000 | 1.00e-02 |   18.7037


100%|██████████| 49/49 [00:17<00:00,  2.80it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  155 |         [31m0.0108[0m |   [31m0.9841[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  377.0000 | 1.00e-02 |   19.0002


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  156 |         [32m0.0025[0m |   [31m0.9946[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  373.0000 | 1.00e-02 |   18.3929


100%|██████████| 49/49 [00:16<00:00,  2.89it/s]
  2%|▏         | 1/49 [00:00<00:07,  6.25it/s]

  157 |         [31m0.0611[0m |   [31m0.9421[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  360.0000 | 1.00e-02 |   18.2479


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  158 |         [31m0.0289[0m |   [31m0.9723[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  381.0000 | 1.00e-02 |   18.2014


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  159 |         [31m0.0198[0m |   [31m0.9897[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  340.0000 | 1.00e-02 |   18.1002


100%|██████████| 49/49 [00:17<00:00,  2.76it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  160 |         [31m0.0359[0m |   [31m0.9665[0m | [31m0.8504[0m |     [31m0.8504[0m |       [31m0.8504[0m |  371.0000 | 1.00e-02 |   19.1304


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  161 |         [31m0.0488[0m |   [31m0.9602[0m | [31m0.8217[0m |     [31m0.8217[0m |       [31m0.8217[0m |  375.0000 | 1.00e-02 |   18.7145


100%|██████████| 49/49 [00:16<00:00,  3.02it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  162 |         [31m0.0708[0m |   [31m0.9508[0m | [31m0.8661[0m |     [31m0.8661[0m |       [31m0.8661[0m |  362.0000 | 1.00e-02 |   17.6178


100%|██████████| 49/49 [00:16<00:00,  3.05it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  163 |         [31m0.0302[0m |   [31m0.9730[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  354.0000 | 1.00e-02 |   17.4307


100%|██████████| 49/49 [00:16<00:00,  2.93it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  164 |         [31m0.0291[0m |   [31m0.9867[0m | [31m0.8943[0m |     [31m0.8943[0m |       [31m0.8943[0m |  376.0000 | 1.00e-02 |   18.0604


100%|██████████| 49/49 [00:17<00:00,  2.82it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  165 |         [31m0.0332[0m |   [31m0.9719[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  374.0000 | 1.00e-02 |   18.7871


100%|██████████| 49/49 [00:16<00:00,  2.93it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  166 |         [31m0.0232[0m |   [31m0.9720[0m | [31m0.8594[0m |     [31m0.8594[0m |       [31m0.8594[0m |  376.0000 | 1.00e-02 |   18.1232


100%|██████████| 49/49 [00:16<00:00,  3.04it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  167 |         [31m0.0437[0m |   [31m0.9711[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  362.0000 | 1.00e-02 |   17.4406


100%|██████████| 49/49 [00:16<00:00,  2.98it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  168 |         [31m0.0292[0m |   [31m0.9790[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  358.0000 | 1.00e-02 |   17.7988


100%|██████████| 49/49 [00:16<00:00,  3.00it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  169 |         [31m0.5766[0m |   [31m0.8976[0m | [31m0.8871[0m |     [31m0.8871[0m |       [31m0.8871[0m |  368.0000 | 1.00e-02 |   17.5910


100%|██████████| 49/49 [00:16<00:00,  2.95it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  170 |         [31m0.0230[0m |   [31m0.9803[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  379.0000 | 1.00e-02 |   17.8412


100%|██████████| 49/49 [00:16<00:00,  2.94it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  171 |         [31m0.0430[0m |   [31m0.9770[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  369.0000 | 1.00e-02 |   18.1036


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  172 |         [31m0.0200[0m |   [31m0.9689[0m | [31m0.8615[0m |     [31m0.8615[0m |       [31m0.8615[0m |  385.0000 | 1.00e-02 |   18.5785


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  173 |         [31m0.0563[0m |   [31m0.9730[0m | [31m0.8085[0m |     [31m0.8085[0m |       [31m0.8085[0m |  351.0000 | 1.00e-02 |   18.3420


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.97it/s]

  174 |         [31m0.0794[0m |   [31m0.9618[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  379.0000 | 1.00e-02 |   18.1207


100%|██████████| 49/49 [00:17<00:00,  2.88it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  175 |         [31m0.1035[0m |   [31m0.9243[0m | [31m0.8281[0m |     [31m0.8281[0m |       [31m0.8281[0m |  368.0000 | 1.00e-02 |   18.3669


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  176 |         [31m0.0255[0m |   [31m0.9719[0m | [31m0.9062[0m |     [31m0.9062[0m |       [31m0.9062[0m |  372.0000 | 1.00e-02 |   18.3059


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  177 |         [31m0.0341[0m |   [31m0.9927[0m | [31m0.9134[0m |     [31m0.9134[0m |       [31m0.9134[0m |  342.0000 | 1.00e-02 |   18.7758


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  178 |         [31m0.0191[0m |   [31m0.9868[0m | [31m0.8730[0m |     [31m0.8730[0m |       [31m0.8730[0m |  380.0000 | 1.00e-02 |   18.6933


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  179 |         [31m0.0110[0m |   [31m0.9824[0m | [31m0.9062[0m |     [31m0.9062[0m |       [31m0.9062[0m |  398.0000 | 1.00e-02 |   18.5579


100%|██████████| 49/49 [00:17<00:00,  2.83it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  180 |         [32m0.0016[0m |   [32m0.9959[0m | [31m0.8769[0m |     [31m0.8769[0m |       [31m0.8769[0m |  369.0000 | 1.00e-02 |   18.6444


100%|██████████| 49/49 [00:17<00:00,  2.81it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  181 |         [31m0.0054[0m |   [31m0.9933[0m | [31m0.8837[0m |     [31m0.8837[0m |       [31m0.8837[0m |  377.0000 | 1.00e-02 |   18.7163


100%|██████████| 49/49 [00:17<00:00,  2.84it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  182 |         [31m0.0321[0m |   [31m0.9848[0m | [31m0.8702[0m |     [31m0.8702[0m |       [31m0.8702[0m |  363.0000 | 1.00e-02 |   18.5767


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  183 |         [31m0.0315[0m |   [31m0.9638[0m | [31m0.8819[0m |     [31m0.8819[0m |       [31m0.8819[0m |  360.0000 | 1.00e-02 |   18.3741


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  184 |         [31m0.0079[0m |   [31m0.9907[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  377.0000 | 1.00e-02 |   18.0428


100%|██████████| 49/49 [00:17<00:00,  2.85it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  185 |         [31m0.0692[0m |   [31m0.9680[0m | [31m0.6788[0m |     [31m0.6788[0m |       [31m0.6788[0m |  368.0000 | 1.00e-02 |   19.1536


100%|██████████| 49/49 [00:17<00:00,  2.88it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  186 |         [31m0.0526[0m |   [31m0.9513[0m | [31m0.8244[0m |     [31m0.8244[0m |       [31m0.8244[0m |  360.0000 | 1.00e-02 |   18.3618


100%|██████████| 49/49 [00:17<00:00,  2.86it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.90it/s]

  187 |         [31m0.0150[0m |   [31m0.9733[0m | [31m0.7857[0m |     [31m0.7857[0m |       [31m0.7857[0m |  374.0000 | 1.00e-02 |   18.6146


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  188 |         [31m0.0141[0m |   [31m0.9849[0m | [31m0.8175[0m |     [31m0.8175[0m |       [31m0.8175[0m |  365.0000 | 1.00e-02 |   18.2992


100%|██████████| 49/49 [00:16<00:00,  2.88it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  189 |         [31m0.0049[0m |   [31m0.9933[0m | [31m0.8889[0m |     [31m0.8889[0m |       [31m0.8889[0m |  370.0000 | 1.00e-02 |   18.3485


100%|██████████| 49/49 [00:17<00:00,  2.88it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  190 |         [31m0.0110[0m |   [32m0.9961[0m | [31m0.8550[0m |     [31m0.8550[0m |       [31m0.8550[0m |  387.0000 | 1.00e-02 |   18.3767


100%|██████████| 49/49 [00:16<00:00,  2.91it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  191 |         [31m0.0428[0m |   [31m0.9907[0m | [31m0.8750[0m |     [31m0.8750[0m |       [31m0.8750[0m |  378.0000 | 1.00e-02 |   18.1335


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  192 |         [31m0.1932[0m |   [31m0.8963[0m | [31m0.3895[0m |     [31m0.3895[0m |       [31m0.3895[0m |  365.0000 | 1.00e-02 |   18.5212


100%|██████████| 49/49 [00:16<00:00,  2.89it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  193 |         [31m0.0439[0m |   [31m0.9475[0m | [31m0.7383[0m |     [31m0.7383[0m |       [31m0.7383[0m |  364.0000 | 1.00e-02 |   18.3977


100%|██████████| 49/49 [00:16<00:00,  2.92it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  194 |         [31m0.0585[0m |   [31m0.9613[0m | [31m0.8640[0m |     [31m0.8640[0m |       [31m0.8640[0m |  346.0000 | 1.00e-02 |   18.2029


100%|██████████| 49/49 [00:16<00:00,  2.90it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  195 |         [31m0.0602[0m |   [31m0.9653[0m | [31m0.8548[0m |     [31m0.8548[0m |       [31m0.8548[0m |  375.0000 | 1.00e-02 |   18.1970


100%|██████████| 49/49 [00:16<00:00,  2.93it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  196 |         [31m0.0472[0m |   [31m0.9677[0m | [31m0.8504[0m |     [31m0.8504[0m |       [31m0.8504[0m |  357.0000 | 1.00e-02 |   18.0140


100%|██████████| 49/49 [00:17<00:00,  2.88it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  197 |         [31m0.0054[0m |   [31m0.9847[0m | [31m0.9062[0m |     [31m0.9062[0m |       [31m0.9062[0m |  359.0000 | 1.00e-02 |   18.3283


100%|██████████| 49/49 [00:17<00:00,  2.87it/s]
  0%|          | 0/49 [00:00<?, ?it/s]

  198 |         [31m0.0603[0m |   [31m0.9619[0m | [31m0.8527[0m |     [31m0.8527[0m |       [31m0.8527[0m |  341.0000 | 1.00e-02 |   18.4697


100%|██████████| 49/49 [00:16<00:00,  2.95it/s]
  2%|▏         | 1/49 [00:00<00:08,  5.71it/s]

  199 |         [31m0.0451[0m |   [31m0.9649[0m | [31m0.9032[0m |     [31m0.9032[0m |       [31m0.9032[0m |  354.0000 | 1.00e-02 |   17.8189


100%|██████████| 49/49 [00:16<00:00,  3.02it/s]


  200 |         [31m0.0273[0m |   [31m0.9748[0m | [31m0.8682[0m |     [31m0.8682[0m |       [31m0.8682[0m |  377.0000 | 1.00e-02 |   17.4596
Loading /home/tannier/data/cache/daloux/2bb9f55995f7f9f3/checkpoint-145.pt... Done
Model restored to its best self.state: 145


# Test 
(to be fair, avoid executing this part of the notebook to often, or use the training set instead)

In [15]:
bert_name = "bert-large"
test_dataset=test_dataset#load_genia_ner()#load_from_brat(root.resource("deft_2020/t3-test"), doc_attributes={"source": "real"})
test_docs, test_sentences, test_zones, test_mentions, test_conflicts, test_tokens, test_deltas, _ = preprocess(
    dataset=test_dataset,
    max_sentence_length=120,
    ner_labels=list(vocs["ner_label"]),
    bert_name=bert_name,
    unknown_labels="drop",
    vocabularies=vocs,
)
test_batcher, test_encoded, test_ids = make_batcher(test_docs, test_sentences, test_zones, test_mentions, test_conflicts, test_tokens)

NameError: name 'test_dataset' is not defined

### Extract the test mentions

In [15]:
pred_batcher = extract_mentions(test_batcher, all_nets=all_nets)
gold_batcher = test_batcher

pred=pd.DataFrame(dict(pred_batcher["mention", ["sentence_id", "begin", "end", "ner_label", "mention_id"]]))
gold=pd.DataFrame(dict(gold_batcher["mention", ["@zone_id", "begin", "end", "ner_label", "mention_id"]]))
gold["sentence_id"] = gold_batcher["zone", "sentence_id"][gold["@zone_id"]]
all_preds.append(pred)

print("{: <15} {: <5} {: <5} {: <5}".format("ner_label", "f1", "prec", "rec"))
print("---------------------------------")
for ner_label_idx, ner_label in enumerate(vocs['ner_label']):
    merged = merge_pred_and_gold(
        pred.query(f'ner_label == {ner_label_idx}'), 
        gold.query(f'ner_label == {ner_label_idx}'), 
        span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["sentence_id", ("begin", "end"), "ner_label"], atom_gold_level=["mention_id"], atom_pred_level=["mention_id"])
    precision = merged['tp'].sum() / merged['pred_count'].sum()
    recall = merged['tp'].sum() / merged['gold_count'].sum()
    f1 = 2/(1/precision + 1/recall)
    f1, precision, recall
    print("{: <15} {:.3f} {:.3f} {:.3f}".format(str(ner_label), f1, precision, recall))
agg = compute_metrics(merge_pred_and_gold(
    pred,
    gold,
    span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
    on=["sentence_id", ("begin", "end"), "ner_label"], atom_gold_level=["mention_id"], atom_pred_level=["mention_id"]))
print("---------------------------------")
print("{: <15} {:.3f} {:.3f} {:.3f}".format("total", agg["f1"], agg["precision"], agg["recall"]))

ner_label       f1    prec  rec  
---------------------------------
pathologie      0.420 0.593 0.325
sosy            0.524 0.530 0.518
---------------------------------
total           0.514 0.534 0.496
