In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Widen the classic-notebook container to the full browser width.
# `display`/`HTML` are imported from the public `IPython.display` module:
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
import logging
# Show INFO-level (and above) messages emitted under the "nlstruct" logger name.
logger = logging.getLogger("nlstruct")
logger.setLevel(logging.INFO)

In [4]:
from nlstruct.text import huggingface_tokenize, regex_sentencize, partition_spans, encode_as_tag, split_into_spans, apply_substitutions, apply_deltas
from nlstruct.dataloaders import load_from_brat, load_genia_ner
from nlstruct.collections import Dataset, Batcher
from nlstruct.utils import merge_with_spans, normalize_vocabularies, factorize_rows, df_to_csr, factorize, torch_global as tg
from nlstruct.modules.crf import BIODecoder, BIOULDecoder
from nlstruct.environment import root, cached
from nlstruct.train import seed_all
from itertools import chain, repeat

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
import math

import pandas as pd
import numpy as np
import re
import string
from transformers import AutoModel, AutoTokenizer

In [5]:
# Let pandas use up to 1000 characters per line before wrapping printed frames.
pd.set_option('display.width', 1000)

In [6]:
import scipy

def select_closest_non_overlapping_gold_mentions(
    gold_ids,
    gold_sentence_ids,
    gold_begins,
    gold_ends,
    
    pred_sentence_ids,
    pred_begins,
    pred_ends,
    
    zone_mention_id,
    zone_mask,
    gold_conflicts,
    gold_conflicts_mask,
):
    """
    Select non overlapping gold mentions (in gold_ids) that are the closest to those found by the model
    Gold mentions are described by gold_* tensors, predicted mentions are described by pred_* tensors
    We select at most one gold mention per zone (zone_mention_id + zone_mask) at each round, and each time
    a gold mention is selected, its overlaps are removed according to the gold_conflicts + gold_conflicts_mask tensors.
    Rounds are repeated until no candidate mention remains in any zone.

    Closeness is the best Dice overlap, 2 * |pred & gold| / (|pred| + |gold|), between a gold span and any
    predicted span of the same sentence. When there is no prediction at all, every candidate scores 0.

    Parameters
    ----------
    gold_ids: torch.Tensor
        Ids of the candidate gold mentions
    gold_sentence_ids, gold_begins, gold_ends: torch.Tensor
        Sentence / begin / end of each candidate gold mention
        (assumed to be aligned with gold_ids -- TODO confirm against caller)
    pred_sentence_ids, pred_begins, pred_ends: torch.Tensor
        Sentence / begin / end of each predicted mention
    zone_mention_id, zone_mask: torch.Tensor
        Mention ids grouped by zone (2d) and the matching validity mask
    gold_conflicts, gold_conflicts_mask: torch.Tensor
        Conflicting mention ids and their mask; both are indexed by mention id

    Returns
    -------
    torch.Tensor
        Selected gold ids, included in "gold_ids"
    """

    device = gold_begins.device
    
    # Nothing to select from
    if len(gold_ids) == 0:
        return torch.as_tensor([], dtype=torch.long, device=device)

    # NOTE(review): presumably `rel` holds, for each zone slot, the position of its mention inside `gold_ids`,
    # and the returned mask drops the slots whose mention is not in `gold_ids` -- confirm against nlstruct.utils.factorize
    [rel], [remaining_mask], _ = factorize([zone_mention_id], [zone_mask], reference_values=gold_ids)
    # Only keep the zones that still contain at least one candidate mention
    remaining_mentions = zone_mention_id[remaining_mask.any(1)]
    rel = rel[remaining_mask.any(1)]
    remaining_mask = remaining_mask[remaining_mask.any(1)]
        
    # keep_mask is indexed by mention id; -1 marks the slots that cannot be selected
    keep_mask = torch.zeros(gold_ids.max()+1, device=device, dtype=torch.bool)
    zone_scores = torch.full(remaining_mask.shape, fill_value=-1, device=device, dtype=torch.float)
    
    if len(pred_begins):
        PRED, GOLD = 0, 1
        SENTENCE_ID, BEGIN, END = 0, 1, 2
        # p: 3 * n_pred * 1 and g: 3 * 1 * n_gold, so that they broadcast to n_pred * n_gold below
        p = torch.stack([pred_sentence_ids, pred_begins, pred_ends], dim=0).unsqueeze(GOLD+1)
        g = torch.stack([gold_sentence_ids, gold_begins, gold_ends], dim=0).unsqueeze(PRED+1)

        # Size of the intersection of each (pred, gold) pair, 0 when the spans are disjoint
        overlap = (torch.min(p[END], g[END]) - torch.max(p[BEGIN], g[BEGIN])).float().clamp(0)
        # Dice coefficient
        # NOTE(review): a pair of zero-length spans would divide by 0 here -- confirm spans are never empty
        overlap = overlap * 2 / (p[END] - p[BEGIN] + g[END] - g[BEGIN])
        # Pairs from different sentences do not count
        score = (p[SENTENCE_ID] == g[SENTENCE_ID]) * overlap
        
        # Best score of each gold mention over all predictions, laid out by zone
        zone_scores = score.max(0).values[rel]
        zone_scores[~remaining_mask] = -1
    else:
        zone_scores[remaining_mask] = 0
    while len(remaining_mask):
        # Pick the best scored remaining mention of each zone
        best_indexer = torch.arange(zone_scores.shape[0], device=device), zone_scores.argmax(1)
        best_mentions = remaining_mentions[best_indexer]
        # Flat list of all the mentions that conflict with any of the mentions picked in this round
        conflicts = gold_conflicts[best_mentions][gold_conflicts_mask[best_mentions]]    
        keep_mask[best_mentions] = True
        # Remove the picked mentions and their conflicts from the candidates
        remaining_mask[best_indexer] = False
        remaining_mask &= ~(remaining_mentions.unsqueeze(-1) == conflicts).any(-1)
        zone_scores[~remaining_mask] = -1
        # Drop the zones that have no candidate left
        zone_scores = zone_scores[remaining_mask.any(1)]
        remaining_mentions = remaining_mentions[remaining_mask.any(1)]
        remaining_mask = remaining_mask[remaining_mask.any(1)]
    return keep_mask.nonzero()[:, 0]

def split_zone_mentions(batch, random_perm=True, observed_zone_sizes=None):
    """
    In a batch, splits mentions between 
    - those that we will consider as being observed
    - and those that we will ask the model to recover
    
    Parameters
    ----------
    batch: Batcher
        Must expose batch["zone", "@mention_id"] (n_zones * max_mentions_per_zone mention ids)
        and batch["zone", "mention_mask"] (boolean mask of the same shape)
    random_perm: bool
        Shuffle the mentions of each zone before splitting them
    observed_zone_sizes: int
        If not None, the first `observed_zone_sizes` mentions of each (shuffled) zone are observed
        (fewer when the zone is smaller than that)
        Otherwise, a random number between 0 and the number of mentions of the zone is drawn for each zone
    
    Returns
    -------
    (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
        Beware of the order, targets come first:
        - target_mentions: flat target mentions (@mention_id)
        - observed_mentions: flat observed mentions (@mention_id)
        - zone_target_mentions: mentions grouped by zone, restricted to the zones that have at least one target
          (n_zones_with_targets * n_mentions_per_zone)
        - target_mask: mask of the targets in zone_target_mentions, since every zone can have a different number of picked target mentions
    """
    zone_mention_id = batch["zone", "@mention_id"]
    zone_mention_mask = batch["zone", "mention_mask"]
    device = zone_mention_id.device
    n_zones, max_zone_size = zone_mention_mask.shape

    # Sort keys: padding slots get a key (2) greater than any valid one (< 1), so valid mentions always come first
    if random_perm:
        perm = torch.rand(zone_mention_id.shape, device=device)
    else:
        perm = torch.zeros(zone_mention_id.shape, device=device, dtype=torch.float)
    perm[~zone_mention_mask] = 2
    perm = perm.argsort(1)

    if observed_zone_sizes is None:
        # floor((n_mentions + 1) * u) with u in [0, 1) is an integer in [0, n_mentions]
        observed_zone_size = ((zone_mention_mask.sum(-1) + 1) * torch.rand(n_zones, dtype=torch.float, device=device)).long()
    else:
        observed_zone_size = torch.full((n_zones,), fill_value=observed_zone_sizes, device=device, dtype=torch.long)

    # Apply the permutation once, to the mentions AND to their mask so that both stay aligned
    # (using the un-permuted mask is only correct when valid slots are already packed to the left)
    zone_rows = torch.arange(n_zones, device=device).unsqueeze(1)
    permuted_mentions = zone_mention_id[zone_rows, perm]
    permuted_mask = zone_mention_mask[zone_rows, perm]
    slot_positions = torch.arange(max_zone_size, device=device).unsqueeze(0)
    is_observed_slot = slot_positions < observed_zone_size.unsqueeze(1)

    # Select mentions that will become features
    observed_mask = is_observed_slot & permuted_mask
    observed_mentions = permuted_mentions[observed_mask]

    # Select mentions that will be hidden from the model (ie to recover)
    target_mask = ~is_observed_slot & permuted_mask
    zones_with_targets = target_mask.any(1)
    zone_target_mentions = permuted_mentions[zones_with_targets]
    target_mask = target_mask[zones_with_targets]
    target_mentions = zone_target_mentions[target_mask]
    return target_mentions, observed_mentions, zone_target_mentions, target_mask

def compute_scores(pred_batcher, gold_batcher, queries=None, prefix='val_', verbose=0):
    """
    Compare predicted mentions against gold mentions and compute the evaluation metrics.

    Relies on names that are not defined in this function: `merge_pred_and_gold`,
    `compute_metrics` and the global `vocs` (used to decode the "ner_label" codes).

    Parameters
    ----------
    pred_batcher: Batcher
        Predictions, must expose "sentence_id", "begin", "end", "ner_label", "mention_id" under "mention"
    gold_batcher: Batcher
        Gold annotations, must expose "@zone_id", "begin", "end", "ner_label", "mention_id" under "mention"
        and "sentence_id" under "zone"
    queries: dict[str; str] or None
        Optional {name: pandas query}: the metrics are also computed on `merged.query(query)`
        and stored under the prefix `prefix + name + "_"`
    prefix: str
        Prefix of the metric names
    verbose: int
        Unused, kept for backward compatibility

    Returns
    -------
    dict
        Metrics as returned by `compute_metrics`, for the whole data and for each query
    """
    if queries is None:
        queries = {}

    pred = pd.DataFrame(dict(pred_batcher["mention", ["sentence_id", "begin", "end", "ner_label", "mention_id"]]))
    gold = pd.DataFrame(dict(gold_batcher["mention", ["@zone_id", "begin", "end", "ner_label", "mention_id"]]))
    # Gold mentions only know their zone: fetch their sentence through it
    gold["sentence_id"] = gold_batcher["zone", "sentence_id"][gold["@zone_id"]]

    # Merge on sentence, spans and ner_label
    # span_policy='exact' is the policy passed here; the original comment mentions 'partial' as another option
    merged = merge_pred_and_gold(
        pred, gold, span_policy='exact',
        on=["sentence_id", ("begin", "end"), "ner_label"], atom_gold_level=["mention_id"], atom_pred_level=["mention_id"])

    # Decode the label codes so that queries can filter on label names
    merged["ner_label"] = np.asarray(vocs["ner_label"])[merged["ner_label"]].astype(str)
    metrics = dict(compute_metrics(merged, prefix=prefix))
    for name, query in queries.items():
        metrics.update(compute_metrics(merged.query(query), prefix=prefix+name+"_"))
    return metrics

In [7]:
# To debug the training, we can just comment the "def run_epoch()" and execute the function body manually without changing anything to it
def extract_mentions(batcher, all_nets, max_depth=10):
    """
    Iteratively extract (possibly overlapping) mentions: the tagger is run up to `max_depth` times per
    batch, and the mentions predicted during a pass are fed back as tag embeddings to the next pass.
    Sentences for which a pass predicts nothing are dropped from the following passes.

    Relies on the globals `batch_size`, `tag_dim`, `evaluating` and `tg`, which are not defined in this function.

    Parameters
    ----------
    batcher: Batcher 
        The batcher containing the text from which we want to extract the mentions (and maybe the gold mentions)
    all_nets: dict
        Must contain "ner_net" (the tagger, exposing a `crf` decoder) and
        "tag_embeddings" (a module with a `weight` matrix, one row per non-null tag)
    max_depth: int
        Max number of times we run the model per sample
        
    Returns
    -------
    Batcher
        Concatenation of the predictions of every pass of every batch
    """
    pred_batches = []
    n_mentions = 0
    ner_net = all_nets["ner_net"]
    tag_embeddings = all_nets["tag_embeddings"]
    with evaluating(all_nets):
        with torch.no_grad():
            for batch_i, batch in enumerate(batcher['sentence'].dataloader(batch_size=batch_size, shuffle=False, sparse_sort_on="token_mask", device=tg.device)):

                # First pass: no mention has been observed yet, so the tag features are all zeros
                tag_embeds = torch.zeros(*batch["sentence", "token"].shape[:2], tag_dim, device=tg.device)
                # Position in the batch of the sentences that are still being processed
                current_sentences_idx = torch.arange(len(batch), device=tg.device)
                mask = batch["token_mask"]
                tokens = batch["token"]

                for i in range(max_depth):
                    # Run the model argmax here
                    ner_res = ner_net(
                        tokens = tokens,
                        mask = mask,
                        tag_embeds = tag_embeds,
                        return_embeddings=True
                    )

                    # Run the linear CRF Viterbi algorithm to compute the most likely sequence
                    pred_tags = ner_net.crf.decode(ner_res["scores"], mask)
                    spans = ner_net.crf.tags_to_spans(pred_tags, mask)

                    # Save predicted mentions
                    # Mention ids are unique across passes and batches (n_mentions is a running offset),
                    # and "depth" records the pass during which the mention was found
                    pred_batch = Batcher({
                        "mention": {
                            "mention_id": torch.arange(n_mentions, n_mentions+len(spans["span_doc_id"]), device=tg.device),
                            "begin": spans["span_begin"],
                            "end": spans["span_end"],
                            "ner_label": spans["span_label"],
                            "@sentence_id": current_sentences_idx[spans["span_doc_id"]],
                            "depth": torch.full_like(spans["span_begin"], fill_value=i),
                        },
                        "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                        "doc": dict(batch["doc"])}, 
                        check=False).sparsify()
                    pred_batches.append(pred_batch)
                    n_mentions += len(spans["span_doc_id"])

                    # Sentences (relative to the current sub-batch) that got at least one mention during this pass
                    non_empty_sentences = torch.unique(spans["span_doc_id"])

                    # Nothing new was predicted: stop iterating on this batch
                    if len(non_empty_sentences) == 0:
                        break

                    # Convert the predicted spans to tags using the same encoding scheme as the one used to decode predicted tags
                    # (We could use a different one: BIODecoder/BIOULDecoder.spans_to_tags is a static function)
                    # Each span is encoded as its own sample (one row per mention), so overlapping mentions do not clash
                    feature_tags = ner_net.crf.spans_to_tags(
                        torch.arange(len(spans["span_begin"]), device=spans["span_begin"].device),
                        spans["span_begin"], 
                        spans["span_end"],
                        spans["span_label"], 
                        n_tokens=batch["sentence", "token"].shape[1],
                        n_samples=len(spans["span_begin"]),
                    )
                    # Positions holding a non-null tag, with the sentence they belong to
                    tag_mention, tag_positions = feature_tags.nonzero(as_tuple=True)
                    tag_sentence = spans["span_doc_id"][tag_mention]
                    tag_values = feature_tags[tag_mention, tag_positions]

                    # Accumulate the tag embeddings of the new mentions on top of those of the previous passes
                    # (tag 0 is skipped by nonzero() above, hence the -1 shift into the embedding matrix),
                    # then only keep the sentences that will go through another pass
                    tag_embeds = tag_embeds.view(-1, tag_dim).index_add_(
                        dim=0,
                        index=tag_sentence * batch["sentence", "token"].shape[1] + tag_positions, 
                        source=tag_embeddings.weight[tag_values-1]).view(len(current_sentences_idx), batch["sentence", "token"].shape[1], tag_dim)[non_empty_sentences]

                    # Restrict the inputs of the next pass to the sentences that produced mentions during this one
                    tokens = tokens[non_empty_sentences]
                    mask = mask[non_empty_sentences]
                    current_sentences_idx = current_sentences_idx[non_empty_sentences]
    return Batcher.concat(pred_batches)

In [8]:
from collections import defaultdict

# Define the training metrics
# Metrics that are not listed below map to False (default value of the defaultdict).
# NOTE(review): "goal" is presumably the value the metric should tend to (0 for losses, 1 for scores),
# "name" the displayed column name, and "format" a (width?, formatter) pair -- confirm against the consumer of this dict
metrics_info = defaultdict(lambda: False)
flt_format = (5, "{:.4f}".format)
metrics_info.update({
    "train_loss": {"goal": 0, "format": flt_format},
    "train_ner_loss": {"goal": 0, "format": flt_format},
    #"train_recall": {"goal": 1, "format": flt_format, "name": "train_rec"},
    #"train_precision": {"goal": 1, "format": flt_format, "name": "train_prec"},
    "train_f1": {"goal": 1, "format": flt_format, "name": "train_f1"},
    
    "val_loss": {"goal": 0, "format": flt_format},
    "val_ner_loss": {"goal": 0, "format": flt_format},
    "val_label_loss": {"goal": 0, "format": flt_format},
    
    "val_f1": {"goal": 1, "format": flt_format, "name": "val_f1"},
    "val_3.1_f1": {"goal": 1, "format": flt_format, "name": "val_3.1_f1"},
    "val_3.2_f1": {"goal": 1, "format": flt_format, "name": "val_3.2_f1"},
    "val_macro_f1": {"goal": 1, "format": flt_format, "name": "val_macro_f1"},
    "val_sosy_f1": {"goal": 1, "format": flt_format, "name": "val_sosy_f1"},
    "val_pathologie_f1": {"goal": 1, "format": flt_format, "name": "val_patho_f1"},
    
    # Bookkeeping values: no "goal" is set for these
    "duration": {"format": flt_format, "name": "   dur(s)"},
    "rescale": {"format": flt_format},
    "n_depth": {"format": flt_format},
    "n_matched": {"format": flt_format},
    "n_targets": {"format": flt_format},
    "n_observed": {"format": flt_format},
    "total_score_sum": {"format": flt_format},
    "lr": {"format": (5, "{:.2e}".format)},
})

In [9]:
def make_batcher(docs, sentences, zones, mentions, conflicts, tokens):
    """
    Re-index the ids of the given frames as contiguous integers and assemble them into a Batcher
    with four tables: "mention", "zone", "sentence" and "doc".
    The input frames are copied first, so they are left untouched.

    Parameters
    ----------
    docs: pd.DataFrame
    sentences: pd.DataFrame
    zones: pd.DataFrame
    mentions: pd.DataFrame
    conflicts: pd.DataFrame
    tokens: pd.DataFrame
    
    Returns
    -------
    (Batcher, dict[str; pd.DataFrame], dict[str; object])
        - batcher: the assembled Batcher
        - encoded: the re-indexed copies of docs, sentences, zones, mentions and tokens (conflicts is not included)
        - ids: the unique ids returned by factorize_rows for token_id, mention_id, zone_id, sentence_id and doc_id
    """
    docs = docs.copy()
    sentences = sentences.copy()
    zones = zones.copy()
    mentions = mentions.copy()
    conflicts = conflicts.copy()
    tokens = tokens.copy()
    
    # Ids are factorized from the most specific (mention) to the least specific (doc): each step reads the
    # (doc_id, sentence_id, ...) columns, so these must still hold their original values at that point
    [tokens["token_id"]], unique_token_id = factorize_rows([tokens["token_id"]])
    [mentions["mention_id"], conflicts["mention_id"], conflicts["mention_id_other"]], unique_mention_ids = factorize_rows(
        [mentions[["doc_id", "sentence_id", "mention_id"]], conflicts[["doc_id", "sentence_id", "mention_id"]], conflicts[["doc_id", "sentence_id", "mention_id_other"]]])
    [zones["zone_id"], mentions["zone_id"]], unique_zone_ids = factorize_rows(
        [zones[["doc_id", "sentence_id", "zone_id"]], mentions[["doc_id", "sentence_id", "zone_id"]]])
    [sentences["sentence_id"], zones["sentence_id"], mentions["sentence_id"], tokens["sentence_id"],], unique_sentence_ids = factorize_rows(
        [sentences[["doc_id", "sentence_id"]], zones[["doc_id", "sentence_id"]], mentions[["doc_id", "sentence_id"]], tokens[["doc_id", "sentence_id"]]])
    [docs["doc_id"], sentences["doc_id"], zones["doc_id"], mentions["doc_id"], tokens["doc_id"]], unique_doc_ids = factorize_rows(
        [docs["doc_id"], sentences["doc_id"], zones["doc_id"], mentions["doc_id"], tokens["doc_id"]])
    
    # Each df_to_csr(rows, cols, values) call builds a sparse (row, position) -> value matrix;
    # the calls without values build the matching masks, declared in the `masks` argument below
    batcher = Batcher({
        "mention": {
            "mention_id": mentions["mention_id"],
            "zone_id": mentions["zone_id"],
            "sentence_id": mentions["sentence_id"],
            "doc_id": mentions["doc_id"],
            "begin": mentions["begin"],
            "end": mentions["end"],
            "ner_label": mentions["ner_label"].cat.codes,
            "conflict_mention_id": df_to_csr(conflicts["mention_id"], conflicts["conflict_idx"], conflicts["mention_id_other"], n_rows=len(unique_mention_ids)),
            "conflict_mask": df_to_csr(conflicts["mention_id"], conflicts["conflict_idx"], n_rows=len(unique_mention_ids)),
        },
        "zone": {
            "zone_id": zones["zone_id"],
            "sentence_id": zones["sentence_id"],
            "doc_id": zones["doc_id"],
            "mention_id": df_to_csr(mentions["zone_id"], mentions["zone_mention_idx"], mentions["mention_id"], n_rows=len(unique_zone_ids)),
            "mention_mask": df_to_csr(mentions["zone_id"], mentions["zone_mention_idx"], n_rows=len(unique_zone_ids)),
        },
        "sentence": {
            "sentence_id": sentences["sentence_id"],
            "doc_id": sentences["doc_id"],
            "mention_id": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], mentions["mention_id"], n_rows=len(unique_sentence_ids)),
            "mention_mask": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], n_rows=len(unique_sentence_ids)),
            "token": df_to_csr(tokens["sentence_id"], tokens["token_idx"], tokens["token"].cat.codes, n_rows=len(unique_sentence_ids)),
            "token_mask": df_to_csr(tokens["sentence_id"], tokens["token_idx"], n_rows=len(unique_sentence_ids)),
            "zone_id": df_to_csr(zones["sentence_id"], zones["zone_idx"], zones["zone_id"], n_rows=len(unique_sentence_ids)),
            "zone_mask": df_to_csr(zones["sentence_id"], zones["zone_idx"], n_rows=len(unique_sentence_ids)),
        },
        "doc": {
            "doc_id": np.arange(len(unique_doc_ids)),
            "sentence_id": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], sentences["sentence_id"], n_rows=len(unique_doc_ids)),
            "sentence_mask": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], n_rows=len(unique_doc_ids)),
            "split": docs["split"].cat.codes,
        }},
        masks={"sentence": {"token": "token_mask", "zone_id": "zone_mask", "mention_id": "mention_mask"}, 
               "mention": {"conflict_mention_id": "conflict_mask"},
               "zone": {"mention_id": "mention_mask"}, 
               "doc": {"sentence_id": "sentence_mask"}}
    )
    return (
        batcher, 
        dict(docs=docs, sentences=sentences, zones=zones, mentions=mentions, tokens=tokens),
        dict(token_id=unique_token_id, mention_id=unique_mention_ids, zone_id=unique_zone_ids, sentence_id=unique_sentence_ids, doc_id=unique_doc_ids)
    )

In [10]:
class NERNet(torch.nn.Module):
    """
    Token tagger: embeddings -> dropout -> linear + ReLU -> batch norm -> per-tag scores.
    The tag scores are meant to be decoded by `self.crf` (a BIODecoder or a BIOULDecoder).

    Parameters
    ----------
    n_labels: int
        Number of NER labels, passed to the decoder
    hidden_dim: int or None
        Output size of the linear layer; defaults to the token embedding size when None
    dropout: float
        Dropout probability applied to the embeddings
    n_tokens: int
        Vocabulary size; inferred from `embeddings` when it is given and n_tokens or token_dim is None
    token_dim: int
        Embedding size; inferred from `embeddings` when it is given and n_tokens or token_dim is None
    embeddings: torch.nn.Module
        Existing embedding module; a new torch.nn.Embedding(n_tokens, token_dim) is created when None
    tag_scheme: str
        "bio" or "bioul"
    metric: str
        "linear", "cosine" or "ema_cosine": kind of the layer that computes the tag scores
    metric_fc_kwargs: dict
        Extra keyword arguments for the "cosine" / "ema_cosine" layers

    Raises
    ------
    ValueError
        If `tag_scheme` or `metric` is not one of the supported values
    """
    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 metric='linear',
                 metric_fc_kwargs=None,
                 ):
        super().__init__()
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                # Read the sizes from the embedding matrix, either directly or one level down
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        assert token_dim is not None, "Provide token_dim or embeddings"
        assert self.embeddings is not None

        dim = (token_dim if n_tokens > 0 else 0)
        self.dropout = torch.nn.Dropout(dropout)
        if tag_scheme == "bio":
            self.crf = BIODecoder(n_labels)
        elif tag_scheme == "bioul":
            self.crf = BIOULDecoder(n_labels)
        else:
            raise ValueError(f"Unknown tag_scheme {tag_scheme!r}, expected 'bio' or 'bioul'")
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        # The batch norm and the scoring layer both consume the output of self.linear, which is
        # hidden_dim wide: sizing them on `dim` only worked when hidden_dim == token_dim
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)

        n_tags = self.crf.num_tags
        metric_fc_kwargs = metric_fc_kwargs if metric_fc_kwargs is not None else {}
        if metric == "linear":
            self.metric_fc = torch.nn.Linear(hidden_dim, n_tags)
        # NOTE(review): CosineSimilarity, EMACosineSimilarity and `rescale` are not defined in this class,
        # they are looked up in the module globals when this branch is executed
        elif metric == "cosine":
            self.metric_fc = CosineSimilarity(hidden_dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        elif metric == "ema_cosine":
            self.metric_fc = EMACosineSimilarity(hidden_dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        else:
            raise ValueError(f"Unknown metric {metric!r}, expected 'linear', 'cosine' or 'ema_cosine'")
    
    def extended_embeddings(self, tokens, mask, **kwargs):
        """
        Embed the tokens. Extra keyword arguments are only forwarded to Bert-like embeddings,
        they are ignored by plain embedding modules.
        """
        # Default case here, size <= 512
        # Small ugly check to see if self.embeddings is Bert-like, then we need to pass a mask
        if hasattr(self.embeddings, 'encoder') or hasattr(self.embeddings, 'transformer'):
            return self.embeddings(tokens, mask, **kwargs)[0]
        else:
            return self.embeddings(tokens)

    def forward(self, tokens, mask, tag_embeds=None, return_embeddings=False):
        """
        Parameters
        ----------
        tokens: torch.Tensor
            n_batch * sequence token ids
        mask: torch.Tensor
            n_batch * sequence boolean mask of the tokens
        tag_embeds: torch.Tensor
            Optional features forwarded as `custom_embeds` to Bert-like embeddings
        return_embeddings: bool
            Also return the token embeddings

        Returns
        -------
        dict
            "scores": n_batch * sequence * n_tags scores
            "embeddings": the token embeddings if return_embeddings is True, else None
        """
        # Embed the tokens
        # shape: n_batch * sequence * token_dim
        embeds = self.extended_embeddings(tokens, mask, custom_embeds=tag_embeds)
        # Zero out the padding positions
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(self.linear(self.dropout(state)))# + state
        # BatchNorm1d normalizes over the first axis: flatten batch and sequence together
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        scores = self.metric_fc(state)
        return {
            "scores": scores,
            "embeddings": embeds if return_embeddings else None,
        }

In [11]:
class LSTMNERNet(torch.nn.Module):
    """
    Token tagger: embeddings -> dropout -> linear + ReLU -> batch norm -> 2-layer BiLSTM -> ReLU -> per-tag scores.
    The tag scores are meant to be decoded by `self.crf` (a BIODecoder or a BIOULDecoder).

    Parameters
    ----------
    n_labels: int
        Number of NER labels, passed to the decoder
    hidden_dim: int or None
        Output size of the linear layer and size of each LSTM direction;
        defaults to the token embedding size when None
    dropout: float
        Dropout probability, applied to the embeddings, to the LSTM input and between the LSTM layers
    n_tokens: int
        Vocabulary size; inferred from `embeddings` when it is given and n_tokens or token_dim is None
    token_dim: int
        Embedding size; inferred from `embeddings` when it is given and n_tokens or token_dim is None
    embeddings: torch.nn.Module
        Existing embedding module; a new torch.nn.Embedding(n_tokens, token_dim) is created when None
    tag_scheme: str
        "bio" or "bioul"
    metric: str
        "linear", "cosine" or "ema_cosine": kind of the layer that computes the tag scores
    metric_fc_kwargs: dict
        Extra keyword arguments for the "cosine" / "ema_cosine" layers

    Raises
    ------
    ValueError
        If `tag_scheme` or `metric` is not one of the supported values
    """
    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 metric='linear',
                 metric_fc_kwargs=None,
                 ):
        super().__init__()
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                # Read the sizes from the embedding matrix, either directly or one level down
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        assert token_dim is not None, "Provide token_dim or embeddings"
        assert self.embeddings is not None

        dim = (token_dim if n_tokens > 0 else 0)
        self.dropout = torch.nn.Dropout(dropout)
        if tag_scheme == "bio":
            self.crf = BIODecoder(n_labels)
        elif tag_scheme == "bioul":
            self.crf = BIOULDecoder(n_labels)
        else:
            raise ValueError(f"Unknown tag_scheme {tag_scheme!r}, expected 'bio' or 'bioul'")
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        # The batch norm consumes the output of self.linear, which is hidden_dim wide:
        # sizing it on `dim` only worked when hidden_dim == token_dim
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)
        self.lstm = torch.nn.LSTM(hidden_dim, 
                                  hidden_dim, dropout=dropout, batch_first=True, num_layers=2, bidirectional=True)

        n_tags = self.crf.num_tags
        metric_fc_kwargs = metric_fc_kwargs if metric_fc_kwargs is not None else {}
        if metric == "linear":
            # The bidirectional LSTM outputs the concatenation of both directions
            self.metric_fc = torch.nn.Linear(hidden_dim*2, n_tags)
        # NOTE(review): CosineSimilarity, EMACosineSimilarity and `rescale` are not defined in this class,
        # they are looked up in the module globals when this branch is executed.
        # They are built with `dim` whereas the state they receive is hidden_dim*2 wide:
        # if their first argument is the input size, it should probably be hidden_dim*2 -- confirm
        elif metric == "cosine":
            self.metric_fc = CosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        elif metric == "ema_cosine":
            self.metric_fc = EMACosineSimilarity(dim, n_tags, rescale=rescale, **metric_fc_kwargs)
        else:
            raise ValueError(f"Unknown metric {metric!r}, expected 'linear', 'cosine' or 'ema_cosine'")
    
    def extended_embeddings(self, tokens, mask, **kwargs):
        """
        Embed the tokens. Extra keyword arguments are only forwarded to Bert-like embeddings,
        they are ignored by plain embedding modules.
        """
        # Default case here, size <= 512
        # Small ugly check to see if self.embeddings is Bert-like, then we need to pass a mask
        if hasattr(self.embeddings, 'encoder') or hasattr(self.embeddings, 'transformer'):
            return self.embeddings(tokens, mask, **kwargs)[0]
        else:
            return self.embeddings(tokens)

    def forward(self, tokens, mask, tag_embeds=None, return_embeddings=False):
        """
        Parameters
        ----------
        tokens: torch.Tensor
            n_batch * sequence token ids
        mask: torch.Tensor
            n_batch * sequence boolean mask of the tokens
        tag_embeds: torch.Tensor
            Optional features forwarded as `custom_embeds` to Bert-like embeddings
        return_embeddings: bool
            Also return the token embeddings

        Returns
        -------
        dict
            "scores": n_batch * sequence * n_tags scores
            "embeddings": the token embeddings if return_embeddings is True, else None
        """
        # Embed the tokens
        # shape: n_batch * sequence * token_dim
        embeds = self.extended_embeddings(tokens, mask, custom_embeds=tag_embeds)
        # Zero out the padding positions
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(self.linear(self.dropout(state)))# + state
        # BatchNorm1d normalizes over the first axis: flatten batch and sequence together
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        
        # The LSTM runs over the full padded sequences (no packing)
        lstm_state = self.lstm(self.dropout(state))[0]
        state = torch.relu(lstm_state)
        scores = self.metric_fc(state)
        return {
            "scores": scores,
            "embeddings": embeds if return_embeddings else None,
        }

In [12]:
#@cached
def preprocess(
    dataset,
    max_sentence_length,
    bert_name,
    ner_labels=None,
    unknown_labels="drop",
    vocabularies=None,
):
    """
    Turn a raw dataset into the frames expected by `make_batcher`: normalize the texts, split them
    into sentences, tokenize them, and group overlapping mentions into zones.

    Parameters
    ----------
        dataset: Dataset
            Must contain the "docs", "mentions" and "fragments" frames
        max_sentence_length: int
            Max number of "words" as defined by the regex in regex_sentencize (so this is not the nb of wordpieces)
        bert_name: str
            bert path/name
        ner_labels: list of str 
            allowed ner labels (to be dropped or filtered)
        unknown_labels: str
            "drop" or "raise": what to do when mentions have a label that is not in `ner_labels`.
            Any value other than "raise" silently drops these mentions.
        vocabularies: dict[str; np.ndarray or list]
            Existing vocabularies to reuse; when None, new ones are built
        
    Returns
    -------
    (pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, dict[str; np.ndarray or list])
        docs:      ('split', 'text', 'doc_id')
        sentences: ('split', 'doc_id', 'sentence_idx', 'begin', 'end', 'text', 'sentence_id')
        zones:     ('doc_id', 'sentence_id', 'zone_id', 'zone_idx')
        mentions:  ('ner_label', 'doc_id', 'sentence_id', 'mention_id', 'depth', 'zone_id', 'text', 'mention_idx', 'begin', 'end', 'zone_mention_idx')
        conflicts: ('doc_id', 'sentence_id', 'mention_id', 'mention_id_other', 'conflict_idx')
        tokens:    ('split', 'token', 'sentence_id', 'token_id', 'token_idx', 'begin', 'end', 'doc_id', 'sentence_idx')
        deltas:    ('doc_id', 'begin', 'end', 'delta')
        vocs: vocabularies to be reused later for encoding more data or decoding predictions

    Raises
    ------
    Exception
        If `unknown_labels` is "raise" and some mentions have a label that is not in `ner_labels`
    """
    print("Dataset:", dataset)

    mentions = dataset["mentions"].rename({"label": "ner_label"}, axis=1)

    
    if ner_labels is not None:
        len_before = len(mentions)
        unknown_ner_labels = list(mentions[~mentions["ner_label"].isin(ner_labels)]["ner_label"].drop_duplicates())
        mentions = mentions[mentions["ner_label"].isin(ner_labels)]
        
        if len(unknown_ner_labels) and unknown_labels == "raise":
            raise Exception(f"Unkown labels in {len_before-len(mentions)} mentions: ", unknown_ner_labels)

    # Give each mention a single span: from the begin of its first fragment to the end of its last one
    # (the default inner merge also drops the mentions that have no fragment)
    mentions = mentions.merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))

    
    print("Transform texts...", end=" ")
    # Insert spaces around punctuation and between letters and digits; `deltas` records the
    # resulting offset shifts so that the mention spans can be moved accordingly
    transformed_docs, deltas = apply_substitutions(
        dataset["docs"], *zip(
            (r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
            (r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
            ("(?<=[a-zA-Z])(?=[0-9])", r" "),
            ("(?<=[0-9])(?=[A-Za-z])", r" "),
        ), apply_unidecode=True)
    transformed_docs = transformed_docs.astype({"text": str})
    transformed_mentions = apply_deltas(mentions, deltas, on=['doc_id'])
    print("done")
    
    print("Splitting into sentences...", end=" ")
    sentences = regex_sentencize(
        transformed_docs, 
        reg_split=r"((?:\s*\n\s*\n)+\s*|(?:(?<=[a-z0-9)]\n)|(?<=[a-z0-9)][ ](?:\.|\n))|(?<=[a-z0-9)][ ][ ](?:\.|\n)))\s*(?=[A-Z]))",
        min_sentence_length=0, max_sentence_length=max_sentence_length,
        # balance_parentheses=True, # default is True
    )
    
    # Attach each mention to its sentence
    [sentence_mentions], sentences, sentence_to_docs = partition_spans([transformed_mentions], sentences, new_id_name="sentence_id", overlap_policy=False)
#     n_sentences_per_mention = sentence_mentions.assign(count=1).groupby(["doc_id", "mention_id"], as_index=False).agg({"count": "sum", "text": "first", "sentence_id": "last"})
#     if n_sentences_per_mention["count"].max() > 1:
#         display(n_sentences_per_mention.query("count > 1"))
#         display(sentences[sentences["sentence_id"].isin(n_sentences_per_mention.query("count > 1")["sentence_id"])]["text"].tolist())
#         raise Exception("Some mentions could be mapped to more than 1 sentences ({})".format(n_sentences_per_mention["count"].max()))
    if sentence_to_docs is not None:
        sentence_mentions = sentence_mentions.merge(sentence_to_docs)
    
    # mention_idx: rank of the mention inside its sentence
    sentence_mentions = sentence_mentions.assign(mention_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id"], {"mention_idx": lambda x: tuple(range(len(x)))})
    print("done")
    
    print("Tokenizing...", end=" ")
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
    # The sentences are lowercased before being tokenized (and are returned lowercased)
    sentences["text"] = sentences["text"].str.lower()
    tokens = huggingface_tokenize(sentences, tokenizer, doc_id_col="sentence_id")
    # NOTE(review): presumably converts the mention spans from character offsets to token positions -- confirm
    sentence_mentions = split_into_spans(sentence_mentions, tokens, pos_col="token_idx", overlap_policy=False)
    print("done")
    
    print("Processing zones (overlapping areas)...", end=" ")
    # Extract overlapping spans
    conflicts = (
        merge_with_spans(sentence_mentions, sentence_mentions, on=["doc_id", "sentence_id", ("begin", "end")], how="outer", suffixes=("", "_other"))
    )

    # ids1, and ids2 make the edges of the graph of overlapping mentions
    # NOTE(review): the merge above is not made on "ner_label", so mentions of different types
    # are presumably linked too -- confirm
    # Reference ids are sorted by span size
    [ids1, ids2], unique_ids = factorize_rows(
        [conflicts[["doc_id", "sentence_id", "mention_id"]], 
         conflicts[["doc_id", "sentence_id", "mention_id_other"]]],
        sentence_mentions.eval("size=(end-begin)").sort_values("size")[["doc_id", "sentence_id", "mention_id"]]
    )
    # NOTE(review): nx.from_scipy_sparse_matrix was removed in networkx 3.0
    # (replaced by nx.from_scipy_sparse_array) -- this requires networkx < 3
    g = nx.from_scipy_sparse_matrix(df_to_csr(ids1, ids2, n_rows=len(unique_ids), n_cols=len(unique_ids)))
    # depth = color of the mention in a greedy coloring of the overlap graph, nodes being visited by increasing id
    colored_nodes = np.asarray(list(nx.coloring.greedy_color(g, strategy=keep_order).items()))
    unique_ids['depth'] = colored_nodes[:, 1][colored_nodes[:, 0].argsort()]
    # A zone is a connected component of the overlap graph
    zone_indices, mention_indices = zip(*chain.from_iterable(zip(repeat(zone_idx), zone) for zone_idx, zone in enumerate(nx.connected_components(g))))
    # conflict_idx: rank of the conflict among those of the same mention
    conflicts = conflicts[["doc_id", "sentence_id", "mention_id", "mention_id_other"]].assign(conflict_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id", "mention_id"], {"conflict_idx": lambda x: tuple(range(len(x)))})

    zone_mentions = pd.DataFrame({
        **unique_ids.iloc[list(mention_indices)],
        "zone_id": zone_indices,
    }).merge(sentence_mentions, on=["doc_id", "sentence_id", "mention_id"]).sort_values(["doc_id", "sentence_id", "zone_id", "mention_id"])
    # zone_mention_idx: rank of the mention inside its zone / zone_idx: rank of the zone inside its sentence
    zone_mentions = zone_mentions.assign(zone_mention_idx=0).nlstruct.groupby_assign(['doc_id', 'sentence_id', 'zone_id'], {"zone_mention_idx": lambda vec: tuple(np.arange(len(vec)))})
    sentence_zones = zone_mentions[["doc_id", "sentence_id", "zone_id"]].drop_duplicates()
    sentence_zones = sentence_zones.assign(zone_idx=0).nlstruct.groupby_assign(['doc_id', 'sentence_id'], {"zone_idx": lambda vec: tuple(np.arange(len(vec)))})
    # NOTE: sentence_mentions is not returned, zone_mentions is what comes back as "mentions"
    sentence_mentions = sentence_mentions.merge(zone_mentions.drop_duplicates(["doc_id", "sentence_id", "mention_id", "depth"]))
    print("done")

    print("Computing vocabularies...")
    # When no vocabulary is given, only the "split" vocabulary is imposed and the others are built here
    [transformed_docs, sentences, sentence_zones, zone_mentions, tokens], vocs = normalize_vocabularies(
        [transformed_docs, sentences, sentence_zones, zone_mentions, tokens], 
        vocabularies={"split": ["train", "val", "test"]} if vocabularies is None else vocabularies,
        train_vocabularies={"source": False, "text": False} if vocabularies is None else False,
        verbose=True)
    print("done")
    return transformed_docs, sentences, sentence_zones, zone_mentions, conflicts, tokens, deltas, vocs

def keep_order(G, colors):
    """Node-ordering strategy for ``nx.coloring.greedy_color``: nodes sorted in ascending order.

    ``G`` is the NetworkX graph being colored. ``colors`` is accepted only because
    the strategy interface passes it; it is never read.
    Visiting the nodes in the order of their ids is what allows the resulting
    color to be used as a mention depth.
    """
    ordered_nodes = list(G)
    ordered_nodes.sort()
    return ordered_nodes

In [13]:
# Encoder name handed to `preprocess` here and read again by the training cell.
bert_name = "camembert-base"
# Alternative corpus tried at this point: load_genia_ner()
dataset = load_from_brat('/home/ytaille/data/resources/corpus_dalloux/CAS_neg_brat')

# Set to True to keep only the documents that have at least one annotated mention.
NEG_ONLY = False

if NEG_ONLY:
    neg_doc_ids = dataset['mentions']['doc_id'].unique()
    # Apply the same `doc_id` filter to the three tables that are restricted.
    neg_docs, neg_mentions, neg_fragments = (
        dataset[table][dataset[table]['doc_id'].isin(neg_doc_ids)]
        for table in ('docs', 'mentions', 'fragments')
    )

    # The remaining tables are passed through unfiltered.
    dataset = Dataset(
        docs=neg_docs, mentions=neg_mentions, fragments=neg_fragments,
        attributes=dataset['attributes'], relations=dataset['relations'], comments=dataset['comments'],
    )

# Preprocess the corpus, keeping only 'NEG' mentions (other labels are dropped).
docs, sentences, zones, mentions, conflicts, tokens, deltas, vocs = preprocess(
    dataset=dataset, max_sentence_length=120, bert_name=bert_name,
    ner_labels=['NEG'], unknown_labels="drop")
batcher, encoded, ids = make_batcher(docs, sentences, zones, mentions, conflicts, tokens)

Dataset: Dataset(
  (docs):       3790 * ('doc_id', 'text', 'split')
  (mentions):   1023 * ('doc_id', 'mention_id', 'label', 'text')
  (fragments):  1023 * ('doc_id', 'mention_id', 'fragment_id', 'begin', 'end')
  (attributes):    0 * ('doc_id', 'mention_id', 'attribute_id', 'label', 'value')
  (relations):     0 * ('doc_id', 'relation_id', 'relation_label', 'from_mention_id', 'to_mention_id')
  (comments):      0 * ('doc_id', 'comment_id', 'mention_id', 'comment')
)
Transform texts... done
Splitting into sentences... done
Tokenizing... 



done
Processing zones (overlapping areas)... done
Computing vocabularies...
done


In [14]:
# Leftover scaffolding of a search over many seeds for a well-balanced split:
#all_test_doc_ids = []
#sims = {}
#for i in range(200):
seed_all(1234567+137)

# Sentence-level views of the documents of each split.
# Split codes presumably follow the "split" vocabulary order (0 = train, 1 = val, 2 = test) — confirm.
train_batcher = batcher['doc'][batcher['doc']['split']==0]['sentence']
val_batcher = batcher['doc'][batcher['doc']['split']==1]['sentence']
# test_batcher = batcher['doc'][batcher['doc']['split']==2]['sentence']

# Alternative random splits that were tried:
# splits = np.zeros(len(train_batcher['doc']), dtype=int)

# val_perc = 0.1
# splits[np.random.choice(np.arange(len(splits)), size=int(val_perc * len(splits)))] = 1

# val_batcher = batcher['sentence'][splits == 1]

# train_val_split = np.random.permutation(len(train_batcher))
# test_batcher = train_batcher[train_val_split[:int(0.1*len(train_val_split))]]['sentence']
# train_batcher = train_batcher[train_val_split[int(0.1*len(train_val_split)):]]['sentence']


def label_frequencies(split_batcher):
    """Return the relative frequency of each NER label among the mentions of ``split_batcher``.

    The result has one entry per label of ``vocs["ner_label"]``, in vocabulary order:
    the number of mentions carrying that label divided by the total number of mentions.
    """
    counts = np.bincount(split_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))
    return counts / len(split_batcher['mention'])


# Compute each distribution once and reuse it for both the distance and the table.
train_freqs = label_frequencies(train_batcher)
val_freqs = label_frequencies(val_batcher)

# Sum of squared differences, i.e. the *squared* L2 distance between the two label distributions
sim = ((val_freqs - train_freqs) ** 2).sum()
print("Similarity (L2 dist) between train and val frequencies:", sim)
print("Frequencies")
#all_test_doc_ids.append((test_doc_ids, sim))
display(pd.DataFrame([
    {"index": "train", **dict(zip(vocs["ner_label"], train_freqs))},
    {"index": "val", **dict(zip(vocs["ner_label"], val_freqs))},
]))

Similarity (L2 dist) between train and val frequencies: 0.0
Frequebncies


Unnamed: 0,index,NEG
0,train,1.0
1,val,1.0


In [17]:
# !pip install /home/yoann/these/DEFT/nlstruct/

In [21]:
import traceback
from tqdm import tqdm

from custom_bert import CustomBertModel
from transformers import AdamW, BertModel

from tqdm import tqdm
from scipy.sparse import csr_matrix

from nlstruct.environment import get_cache
from nlstruct.utils import evaluating, torch_global as tg, freeze
from nlstruct.scoring import compute_metrics, merge_pred_and_gold
from nlstruct.train import make_optimizer_and_schedules, run_optimization, seed_all
from nlstruct.train.schedule import ScaleOnPlateauSchedule, LinearSchedule, ConstantSchedule
    
# Run on the first CUDA device and register it with nlstruct's `torch_global` helper.
device = torch.device('cuda:0')
tg.set_device(device)
all_preds, histories = [], []

# Drop what a previous run of this cell left in the notebook namespace, to release gpu memory
# before allocating the parameters of a new model.
# Running the experiment inside a function would release every variable on exit, but keeping
# them global lets us debug after this cell if something goes wrong.
for _stale_name in ("all_nets", "optim", "schedules", "final_schedule", "state"):
    if _stale_name in globals():
        del globals()[_stale_name]
del _stale_name
    
# Hyperparameter search
# Width shared by the token embeddings, the tag embeddings and the default hidden layer:
# 1024 when a "large" checkpoint is selected, 768 otherwise.
bert_dim = 1024 if "large" in bert_name else 768
for layer, hidden_dim, scheme, seed, lr, bert_lr, n_schedules, dropout in [
    (3, bert_dim, "bioul", 12,  1e-2, 4e-5, 1, 0.1),
]:
    print(layer, hidden_dim, scheme, seed, lr, bert_lr, n_schedules, dropout)
    #seed = 123456
    seed_all(seed) # /!\ Super important to enable reproducibility

    # Fixed settings of the run (previously tried values are kept as trailing comments)
    tag_dim = bert_dim  # dimension of the tag embeddings that run_epoch adds up per token
    max_grad_norm = 5.  # gradient clipping threshold used in run_epoch
    #lr = 1e-3
    #bert_lr = 6e-5
    tags_lr = bert_lr
    bert_weight_decay = 0.0000  # NOTE(review): not read anywhere in this cell
    batch_size = 128
    random_perm=True
    observed_zone_sizes=None
    n_per_zone = "uniform"  # NOTE(review): not read anywhere in this cell
    n_freeze = layer + 2 #4
    custom_embeds_layer_index = 19 if "large" in bert_name else 11  #layer#2
    #hidden_dim = 256
    bert_dropout = 0.1
    top_dropout = dropout

    ner_net = NERNet(
            n_tokens=len(vocs["token"]),
            token_dim=bert_dim,
            n_labels=len(vocs["ner_label"]),
            embeddings=CustomBertModel.from_pretrained(bert_name, custom_embeds_layer_index=custom_embeds_layer_index),

            dropout=top_dropout,
            hidden_dim=hidden_dim,
            tag_scheme=scheme,
            metric='linear') # cosine might be better but looks less stable, oddly,
    all_nets = torch.nn.ModuleDict({
        "ner_net": ner_net,
        # One row per non-zero tag: run_epoch looks up tag id t at row t - 1
        "tag_embeddings": torch.nn.Embedding(ner_net.crf.num_tags - 1, tag_dim),
    }).to(device=tg.device)
    del ner_net

    # Override the dropout probability of every Dropout module inside the pretrained encoder
    for module in all_nets["ner_net"].embeddings.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = bert_dropout
    all_nets.train()

    # Define the optimizer, maybe multiple learning rate / schedules per parameters groups.
    # Each tuple of learning rates has one value per parameter-name pattern listed below, in the same order
    # (presumably: first tuple = initial values, then (schedule class, target values, duration) — confirm
    # against make_optimizer_and_schedules).
    optim, schedules = make_optimizer_and_schedules(all_nets, AdamW, {
        "lr": [
                               (lr,    bert_lr,    bert_lr,    tags_lr),
            (ConstantSchedule, (lr,    bert_lr,    bert_lr,    tags_lr),    15),
            (ConstantSchedule, (lr/4,  bert_lr/4,  bert_lr/4,  tags_lr/4),  15),
            (ConstantSchedule, (lr/16, bert_lr/16, bert_lr/16, tags_lr/16), 10),
            (ConstantSchedule, (lr/64, bert_lr/64, bert_lr/64, tags_lr/64), 10),
        ][:n_schedules+1],
    }, [
        # Raw strings: "\." is not a valid escape sequence in a regular string literal.
        r"(?!ner_net\.embeddings\.|tag_embeddings\.).*",
        r"ner_net\.embeddings\..*(bias|LayerNorm\.weight)",
        # NOTE(review): a look-ahead placed after a greedy ".*" always succeeds at the end of the name,
        # so this pattern also matches the bias / LayerNorm.weight parameters. Both encoder groups use
        # bert_lr, so the learning rates are unaffected; whether the overlap matters otherwise depends
        # on how make_optimizer_and_schedules resolves patterns — confirm.
        r"ner_net\.embeddings\..*(?!bias|LayerNorm\.weight)",
        r"tag_embeddings\..*",
    # NOTE(review): this is a float and equals (n + 1) / batch_size rather than ceil(n / batch_size) — confirm intent
    ], num_iter_per_epoch=(len(train_batcher) + 1) / batch_size)
    final_schedule = ScaleOnPlateauSchedule('lr', optim, patience=4, factor=0.25, verbose=True, mode='max')

    # Freeze some bert layers
    # - n_freeze = 0 to freeze nothing
    # - n_freeze = 1 to freeze word embeddings / position embeddings / ...
    # - n_freeze = 2..13 to freeze the first, second ... 12th layer of bert
    # The layer index is taken from the first ".<number>." component of the parameter name.
    for name, param in all_nets.named_parameters():
        match = re.search(r"\.(\d+)\.", name)
        if match and int(match.group(1)) < n_freeze - 1:
            freeze([param])
    if n_freeze > 0:
        if hasattr(all_nets['ner_net'].embeddings, 'embeddings'):
            freeze(all_nets['ner_net'].embeddings.embeddings)
        else:
            # No `embeddings` sub-module on the encoder: the whole encoder is frozen instead
            freeze(all_nets['ner_net'].embeddings)

    with_tqdm = True
    state = {"all_nets": all_nets, "optim": optim, "schedules": schedules, "final_schedule": final_schedule}  # all we need to restart the training from a given epoch

    def run_epoch():
        """Train for one pass over ``train_batcher``, then score the model on train and validation data.

        Works on the variables defined by the enclosing cell (``all_nets``, ``optim``, ``schedules``,
        ``train_batcher``, ``val_batcher``, ``batch_size``, ``tag_dim``, ``max_grad_norm``, ...) and
        updates the model, optimizer and schedules in place.

        Returns
        -------
        dict
            ``train_ner_loss`` (mean of the batch losses weighted by the number of sentences),
            the entries returned by ``compute_scores`` for the train and validation sets,
            ``val_macro_f1`` (copy of ``val_3.1_f1``), ``n_matched`` (number of gold mentions
            selected as targets during the epoch) and ``lr`` (first value of the ``lr`` schedule).
        """
        # Predicted / selected gold mentions of every batch, concatenated at the end to score the train set
        pred_batches = []
        gold_batches = []

        total_train_ner_loss = 0
        total_train_acc = 0  # NOTE(review): never updated nor returned
        total_train_ner_size = 0

        # Predicted mentions get ids starting right after the number of gold training mentions
        n_mentions = len(train_batcher["mention"])
        n_matched_mentions = 0
        # NOTE(review): the next two counters are accumulated but not part of the returned metrics
        n_target_mentions = 0
        n_observed_mentions = 0

        with tqdm(train_batcher['sentence'].dataloader(batch_size=batch_size, shuffle=True, sparse_sort_on="token_mask", device=device), disable=not with_tqdm) as bar:
            for batch_i, batch in enumerate(bar):
                optim.zero_grad()

                # Shuffle and split mentions in each zone between observed and target
                target_mentions, observed_mentions, zone_target_mentions, target_mask = split_zone_mentions(
                    batch,
                    random_perm=random_perm,
                    observed_zone_sizes=observed_zone_sizes,
                )
                n_target_mentions += len(target_mentions)
                n_observed_mentions += len(observed_mentions)

                # Compute the token tags of the observed (maybe overlapping) mentions, one row per mention:
                # each observed mention is encoded as its own sample (sample index = arange, n_samples = number
                # of observed mentions)
                feature_tags = all_nets["ner_net"].crf.spans_to_tags(
                    torch.arange(len(observed_mentions), device=observed_mentions.device),
                    batch["mention", "begin"][observed_mentions], 
                    batch["mention", "end"][observed_mentions], 
                    batch["mention", "ner_label"][observed_mentions], 
                    n_tokens=batch["sentence", "token"].shape[1],
                    n_samples=len(observed_mentions),
                )
                # (mention row, token position) of every non-zero tag, and the sentence each one belongs to
                tag_mention, tag_positions = feature_tags.nonzero(as_tuple=True)
                tag_sentence = batch["zone", "@sentence_id"][batch["mention", "@zone_id"]][observed_mentions][tag_mention]
                tag_values = feature_tags[tag_mention, tag_positions]
                # Sum the tag embeddings into a zero (n_sentences * n_tokens, tag_dim) grid addressed by the
                # flat index sentence * n_tokens + position, then reshape it back to (n_sentences, n_tokens, tag_dim).
                # index_add_ accumulates, so tokens covered by several observed mentions receive the sum of
                # their tag embeddings. Tag ids are shifted by one because the embedding table has
                # num_tags - 1 rows.
                tag_embeds = torch.zeros(*batch["sentence", "token"].shape[:2], tag_dim, device=tg.device).view(-1, tag_dim).index_add_(
                    dim=0,
                    index=tag_sentence * batch["sentence", "token"].shape[1] + tag_positions, 
                    source=all_nets["tag_embeddings"].weight[tag_values-1]).view(*batch["sentence", "token"].shape[:2], tag_dim)

                ##################################
                #       RUN THE NER MODEL        #
                ##################################
                # Run the NER network on the tokens and the observed tag embeddings to get the tag scores
                # (and the embeddings)
                mask = batch["token_mask"]
                ner_res = all_nets["ner_net"](
                    tokens = batch["token"],
                    mask = mask,
                    tag_embeds = tag_embeds,
                    return_embeddings=True,
                )
                scores = ner_res["scores"]
                embeds = ner_res["embeddings"]  # NOTE(review): not used in the rest of this function

                # Run the linear CRF Viterbi algorithm to compute the most likely sequence
                spans = all_nets["ner_net"].crf.tags_to_spans(all_nets["ner_net"].crf.decode(scores, mask), mask)

                # Save predicted mentions
                pred_batch = Batcher({
                    "mention": {
                        "mention_id": torch.arange(n_mentions, n_mentions+len(spans["span_doc_id"]), device=device),
                        "begin": spans["span_begin"],
                        "end": spans["span_end"],
                        "ner_label": spans["span_label"],
                        "@sentence_id": spans["span_doc_id"],
                    },
                    "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                    "doc": dict(batch["doc"])}, 
                    check=False)
                pred_batches.append(pred_batch)
                # Advance the id offset so that the next batch's predictions get fresh ids
                n_mentions += len(spans["span_doc_id"])

                ##################################
                #      NER LOSS COMPUTATION      #
                ##################################
                # Among the target mentions, pick the non-overlapping gold mentions closest to the predictions
                matched_mentions = select_closest_non_overlapping_gold_mentions(
                    gold_ids=target_mentions,
                    gold_sentence_ids=batch["zone", "@sentence_id"][batch["mention", "@zone_id"]][target_mentions],
                    gold_begins=batch["mention", "begin"][target_mentions],
                    gold_ends=batch["mention", "end"][target_mentions],

                    pred_sentence_ids=spans["span_doc_id"],
                    pred_begins=spans["span_begin"],
                    pred_ends=spans["span_end"],

                    zone_mention_id=batch["zone", "@mention_id"],
                    zone_mask=batch["zone", "mention_mask"],

                    gold_conflicts=batch["mention", "@conflict_mention_id"],
                    gold_conflicts_mask=batch["mention", "conflict_mask"],
                )
                n_matched_mentions += len(matched_mentions)
                gold_batches.append(batch["mention", matched_mentions].sparsify())

                # Compute the tokens label tag of the selected non-overlapping gold mentions to infer from the model
                # (here the sample index of each span is its sentence, and there is one sample per sentence)
                target_tags = all_nets["ner_net"].crf.spans_to_tags(
                    batch["zone", "@sentence_id"][batch["mention", "@zone_id"][matched_mentions]],
                    batch["mention", "begin"][matched_mentions], 
                    batch["mention", "end"][matched_mentions], 
                    batch["mention", "ner_label"][matched_mentions], 
                    n_tokens=batch["sentence", "token"].shape[1],
                    n_samples=batch["sentence", "token"].shape[0],
                )
                # Run the linear CRF forward algorithm on the tokens to compute the loglikelihood of the targets
                ner_loss = -all_nets["ner_net"].crf(scores, mask, target_tags, reduction="mean")
                # Weight the batch loss by its number of sentences so the epoch loss is a per-sentence mean
                total_train_ner_loss += float(ner_loss) * len(batch["sentence"])
                total_train_ner_size += len(batch["sentence"])

                loss = ner_loss

                # Perform optimization step
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_nets.parameters(), max_grad_norm)
                optim.step()
                # Schedules are stepped once per batch
                for schedule_name, schedule in schedules.items():
                    schedule.step()

        # Compute precision, recall and f1 on train set
        ner_pred = Batcher.concat(pred_batches)
        ner_gold = Batcher.concat(gold_batches)

        # Train scores compare the predictions with the gold mentions that were selected as targets;
        # validation scores compare freshly extracted mentions with all the validation mentions
        train_metrics    = compute_scores(ner_pred, ner_gold, prefix='train_')
        val_metrics     = compute_scores(extract_mentions(val_batcher, all_nets=all_nets), val_batcher, prefix='val_',
            queries={
                "3.1": "ner_label in ['NEG']",
                #"3.2": "ner_label in ['anatomie', 'dose', 'examen', 'mode', 'moment', 'substance', 'traitement', 'valeur']",
            }
                                        )
        # final_schedule.step(val_f1, state["epoch"])

        return \
        {
            # max(..., 1) guards against a division by zero when no sentence was seen
            "train_ner_loss": total_train_ner_loss / max(total_train_ner_size, 1),
            **train_metrics,
            **val_metrics,
            # Only the "3.1" query is active, so the macro average reduces to that single score
            "val_macro_f1": val_metrics["val_3.1_f1"], #(val_metrics["val_3.1_f1"] + val_metrics["val_3.2_f1"]) / 2.,
            "n_matched": n_matched_mentions,
            "lr": schedules['lr'].get_val()[0],
        }

    try:
        best, history = run_optimization(
            main_score = "val_f1", # do not earlystop based on validation
            metrics_info=metrics_info,
            max_epoch=200,
            patience=None,
            state=state, 
            cache_policy="all", # only store metrics, not checkpoints
            cache=get_cache("daloux", {"seed": seed, "train_batcher": train_batcher, "val_batcher": None, "random_perm": random_perm, "observed_zone_sizes": observed_zone_sizes, "batch_size": batch_size, "max_grad_norm": max_grad_norm, **state}, loader=torch.load, dumper=torch.save),  # where to store the model (main name + hashed parameters)
            epoch_fn=run_epoch,
            n_save_checkpoints=2,
#             exit_on_score=0.92,
        )
        # histories.append({"layer": layer, "hidden_dim": hidden_dim, "scheme": scheme, "seed": seed, "history": history})
    except Exception as e:
        
        # We catch any exception otherwise some variables (including torch parameters on the gpu) end up being stored globally in sys.last_value, leading to memory errors)
        traceback.print_exc()
        break
    finally:
        pass
        #del optim, schedules, final_schedule, state

3 768 bioul 12 0.01 4e-05 1 0.1
before layer norm


  warn(f"Entry '{key}' in the state seems to be mutable but has no load_state_dict/state_dict methods. This could lead to unpredictable behaviors.")
100%|██████████| 24/24 [00:07<00:00,  3.25it/s]
INFO:nlstruct:epoch | train_ner_loss | train_f1 | [31mval_f1[0m | val_3.1_f1 | val_macro_f1 | n_matched |       lr |    dur(s)
INFO:nlstruct:    1 |        [32m13.4219[0m |   [32m0.0000[0m | [32m0.0009[0m |     [32m0.0009[0m |       [32m0.0009[0m |  379.0000 | 1.00e-02 |   15.2738
100%|██████████| 24/24 [00:07<00:00,  3.25it/s]
INFO:nlstruct:    2 |         [32m6.1301[0m |   [32m0.0042[0m | [31m0.0000[0m |     [31m0.0000[0m |       [31m0.0000[0m |  370.0000 | 1.00e-02 |   14.7829
100%|██████████| 24/24 [00:07<00:00,  3.19it/s]
INFO:nlstruct:    3 |         [32m5.2101[0m |   [32m0.0106[0m | [31m0.0000[0m |     [31m0.0000[0m |       [31m0.0000[0m |  394.0000 | 1.00e-02 |    8.5906
100%|██████████| 24/24 [00:06<00:00,  3.45it/s]
INFO:nlstruct:    4 |         [32m

# Test 
(to be fair, avoid executing this part of the notebook too often, or use the training set instead)

In [15]:
# `bert_name` is deliberately NOT reassigned here: the test corpus is preprocessed with the same
# encoder name as the training corpus (the one `vocs` and the trained model were built with), and the
# global read by the training cell is left untouched.
#
# The held-out corpus must already be loaded in `test_dataset`. Loaders used so far:
#   load_genia_ner()
#   load_from_brat(root.resource("deft_2020/t3-test"), doc_attributes={"source": "real"})
if "test_dataset" not in globals():
    raise NameError(
        "name 'test_dataset' is not defined: load the test corpus into `test_dataset` "
        "before running this cell"
    )

# Reuse the training label set and vocabularies so that the test ids line up with the trained model.
test_docs, test_sentences, test_zones, test_mentions, test_conflicts, test_tokens, test_deltas, _ = preprocess(
    dataset=test_dataset,
    max_sentence_length=120,
    ner_labels=list(vocs["ner_label"]),
    bert_name=bert_name,
    unknown_labels="drop",
    vocabularies=vocs,
)
test_batcher, test_encoded, test_ids = make_batcher(test_docs, test_sentences, test_zones, test_mentions, test_conflicts, test_tokens)

NameError: name 'test_dataset' is not defined

### Extract the test mentions

In [15]:
def precision_recall_f1(tp, pred_count, gold_count):
    """Return ``(precision, recall, f1)`` computed from true-positive, predicted and gold counts.

    A ratio whose denominator is zero is reported as 0.0 instead of producing NaN / inf,
    so labels without any prediction or without any match still print a number.
    """
    precision = tp / pred_count if pred_count else 0.0
    recall = tp / gold_count if gold_count else 0.0
    # Harmonic mean of precision and recall, written so that it never divides by zero
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Arguments shared by every pred / gold alignment below: mentions are matched on their sentence,
# their (begin, end) span and their label, with span_policy='exact'
# (another policy, 'partial', was mentioned as an alternative).
match_kwargs = dict(
    span_policy='exact',
    on=["sentence_id", ("begin", "end"), "ner_label"],
    atom_gold_level=["mention_id"],
    atom_pred_level=["mention_id"],
)

pred_batcher = extract_mentions(test_batcher, all_nets=all_nets)
gold_batcher = test_batcher

pred = pd.DataFrame(dict(pred_batcher["mention", ["sentence_id", "begin", "end", "ner_label", "mention_id"]]))
gold = pd.DataFrame(dict(gold_batcher["mention", ["@zone_id", "begin", "end", "ner_label", "mention_id"]]))
# Gold mentions only reference their zone: look up the sentence id through it
gold["sentence_id"] = gold_batcher["zone", "sentence_id"][gold["@zone_id"]]
all_preds.append(pred)

# One row of scores per label, then the scores over all labels
print("{: <15} {: <5} {: <5} {: <5}".format("ner_label", "f1", "prec", "rec"))
print("---------------------------------")
for ner_label_idx, ner_label in enumerate(vocs['ner_label']):
    merged = merge_pred_and_gold(
        pred.query(f'ner_label == {ner_label_idx}'),
        gold.query(f'ner_label == {ner_label_idx}'),
        **match_kwargs)
    precision, recall, f1 = precision_recall_f1(
        merged['tp'].sum(), merged['pred_count'].sum(), merged['gold_count'].sum())
    print("{: <15} {:.3f} {:.3f} {:.3f}".format(str(ner_label), f1, precision, recall))
agg = compute_metrics(merge_pred_and_gold(pred, gold, **match_kwargs))
print("---------------------------------")
print("{: <15} {:.3f} {:.3f} {:.3f}".format("total", agg["f1"], agg["precision"], agg["recall"]))

ner_label       f1    prec  rec  
---------------------------------
pathologie      0.420 0.593 0.325
sosy            0.524 0.530 0.518
---------------------------------
total           0.514 0.534 0.496
