In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
from nlstruct.text import huggingface_tokenize, regex_sentencize, partition_spans, encode_as_tag, split_into_spans, apply_substitutions, apply_deltas
from nlstruct.dataloaders import *
from nlstruct.collections import Dataset, Batcher
from nlstruct.utils import merge_with_spans, normalize_vocabularies, factorize_rows, df_to_csr, factorize, torch_global as tg
from nlstruct.modules.crf import BIODecoder, BIOULDecoder
from nlstruct.environment import root, cached
from nlstruct.train import seed_all
from itertools import chain, repeat

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
import math

import pandas as pd
import numpy as np
import re
import string
from transformers import AutoModel, AutoTokenizer

In [4]:
import logging

logger = logging.getLogger("nlstruct")
logger.setLevel(logging.DEBUG)

In [5]:
pd.set_option('display.width', 1000)

In [6]:
# To debug the training, we can just comment the "def run_epoch()" and execute the function body manually without changing anything to it
def extract_mentions(batcher, all_nets, max_depth=10):
    """
    Parameters
    ----------
    batcher: Batcher 
        The batcher containing the text from which we want to extract the mentions (and maybe the gold mentions)
    ner_net: torch.nn.Module
    max_depth: int
        Max number of times we run the model per sample
        
    Returns
    -------
    Batcher
    """
    pred_batches = []
    n_mentions = 0
    ner_net = all_nets["ner_net"]
    with evaluating(all_nets):
        with torch.no_grad():
            for batch_i, batch in enumerate(batcher['sentence'].dataloader(batch_size=batch_size, shuffle=False, sparse_sort_on="token_mask", device=tg.device)):

                mask = batch["token_mask"]

                res = all_nets["ner_net"].forward(
                    tokens=batch["token"],
                    mask=mask,
                    return_loss = False,
                )

                spans = res['sampled_spans']

                pred_batch = Batcher({
                    "mention": {
                        "mention_id": torch.arange(n_mentions, n_mentions+len(spans['span_doc_id']), device=device),
                        "begin": spans["span_begin"],
                        "end": spans["span_end"],
                        "ner_label": spans["span_label"],
                        "@sentence_id": spans["span_doc_id"],
                    },
                    "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                    "doc": dict(batch["doc"])}, 
                    check=False)

                pred_batches.append(pred_batch)
            
    return Batcher.concat(pred_batches)

In [7]:
from collections import defaultdict

# Define the training metrics
metrics_info = defaultdict(lambda: False)
flt_format = (5, "{:.4f}".format)
metrics_info.update({
    "train_loss": {"goal": 0, "format": flt_format},
    "train_ner_loss": {"goal": 0, "format": flt_format},
    #"train_recall": {"goal": 1, "format": flt_format, "name": "train_rec"},
    #"train_precision": {"goal": 1, "format": flt_format, "name": "train_prec"},
    "train_f1": {"goal": 1, "format": flt_format, "name": "train_f1"},
    
    "val_loss": {"goal": 0, "format": flt_format},
    "val_ner_loss": {"goal": 0, "format": flt_format},
    "val_label_loss": {"goal": 0, "format": flt_format},
    
    "val_f1": {"goal": 1, "format": flt_format, "name": "val_f1"},
    "val_3.1_f1": {"goal": 1, "format": flt_format, "name": "val_3.1_f1"},
    "val_3.2_f1": {"goal": 1, "format": flt_format, "name": "val_3.2_f1"},
    "val_macro_f1": {"goal": 1, "format": flt_format, "name": "val_macro_f1"},
    "val_sosy_f1": {"goal": 1, "format": flt_format, "name": "val_sosy_f1"},
    "val_pathologie_f1": {"goal": 1, "format": flt_format, "name": "val_patho_f1"},
    
    "duration": {"format": flt_format, "name": "   dur(s)"},
    "rescale": {"format": flt_format},
    "n_depth": {"format": flt_format},
    "n_matched": {"format": flt_format},
    "n_targets": {"format": flt_format},
    "n_observed": {"format": flt_format},
    "total_score_sum": {"format": flt_format},
    "lr": {"format": (5, "{:.2e}".format)},
})

In [8]:
from nlstruct.utils import encode_ids

In [9]:
#@cached
def preprocess_train(
    dataset,
    max_sentence_length,
    bert_name,
    ner_labels=None,
    unknown_labels="drop",
    vocabularies=None,
    frag_merge=True,
):
    """
    Parameters
    ----------
        dataset: Dataset
        max_sentence_length: int
            Max number of "words" as defined by the regex in regex_sentencize (so this is not the nb of wordpieces)
        bert_name: str
            bert path/name
        ner_labels: list of str 
            allowed ner labels (to be dropped or filtered)
        unknown_labels: str
            "drop" or "raise"
        vocabularies: dict[str; np.ndarray or list]
    Returns
    -------
    (pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, dict[str; np.ndarray or list])
        docs:      ('split', 'text', 'doc_id')
        sentences: ('split', 'doc_id', 'sentence_idx', 'begin', 'end', 'text', 'sentence_id')
        mentions:  ('ner_label', 'doc_id', 'sentence_id', 'mention_id', 'depth', 'text', 'mention_idx', 'begin', 'end')
        tokens:    ('split', 'token', 'sentence_id', 'token_id', 'token_idx', 'begin', 'end', 'doc_id', 'sentence_idx')
        deltas:    ('doc_id', 'begin', 'end', 'delta')
        vocs: vocabularies to be reused later for encoding more data or decoding predictions
    """
    print("Dataset:", dataset)
    mentions = dataset["mentions"].rename({"label": "ner_label"}, axis=1)
    if ner_labels is not None:
        len_before = len(mentions)
        unknown_ner_labels = list(mentions[~mentions["ner_label"].isin(ner_labels)]["ner_label"].drop_duplicates())
        mentions = mentions[mentions["ner_label"].isin(ner_labels)]
        if len(unknown_ner_labels) and unknown_labels == "raise":
            raise Exception(f"Unkown labels in {len_before-len(mentions)} mentions: ", unknown_ner_labels)
    # Check that there is no mention overlap
    if frag_merge:
        mentions = mentions.merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))
    print("Transform texts...", end=" ")
    docs, deltas = apply_substitutions(
        dataset["docs"], *zip(
            #(r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
            #(r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
            ("(?<=[a-zA-Z])(?=[0-9])", r" "),
            ("(?<=[0-9])(?=[A-Za-z])", r" "),
        ), apply_unidecode=True)
    docs = docs.astype({"text": str})
    transformed_mentions = apply_deltas(mentions, deltas, on=['doc_id'])
    print("done")
    
    print("Splitting into sentences...", end=" ")
    sentences = regex_sentencize(
        docs, 
        reg_split=r"(?<=[.])(\s*\n+)|(?=, [0-9]\))",
        min_sentence_length=None, max_sentence_length=max_sentence_length,
        balance_parentheses=False,
        # balance_parentheses=True, # default is True
    )
    [mentions], sentences, sentence_to_docs = partition_spans([transformed_mentions], sentences, new_id_name="sentence_id", overlap_policy=False)
    
    mentions = mentions[(mentions['begin']>0) & (mentions['end']>0)]
    
    n_sentences_per_mention = mentions.assign(count=1).groupby(["doc_id", "mention_id"], as_index=False).agg({"count": "sum", "text": tuple, "sentence_id": tuple})
    if n_sentences_per_mention["count"].max() > 1:
        display(n_sentences_per_mention.query("count > 1"))
        display(sentences[sentences["sentence_id"].isin(n_sentences_per_mention.query("count > 1")["sentence_id"].explode())]["text"].tolist())
        raise Exception("Some mentions could be mapped to more than 1 sentences ({})".format(n_sentences_per_mention["count"].max()))
    if sentence_to_docs is not None:
        mentions = mentions.merge(sentence_to_docs)
    mentions = mentions.assign(mention_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id"], {"mention_idx": lambda x: tuple(range(len(x)))})
    print("done")
    
    print("Tokenizing...", end=" ")
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
    sentences["text"] = sentences["text"].str.lower()
    tokens = huggingface_tokenize(sentences, tokenizer, doc_id_col="sentence_id")
    
    print(tokens)
    
    mentions = split_into_spans(mentions, tokens, pos_col="token_idx", overlap_policy=False)
    print("done")
    
    print("Processing nestings (overlapping areas)...", end=" ")
    # Extract overlapping spans
    conflicts = (
        merge_with_spans(mentions, mentions, on=["doc_id", "sentence_id", ("begin", "end")], how="outer", suffixes=("", "_other"))
    )
    # ids1, and ids2 make the edges of the overlapping mentions of the same type (see the "ner_label")
    [ids1, ids2], unique_ids = factorize_rows(
        [conflicts[["doc_id", "sentence_id", "mention_id"]], 
         conflicts[["doc_id", "sentence_id", "mention_id_other"]]],
        mentions.eval("size=(end-begin)").sort_values("size")[["doc_id", "sentence_id", "mention_id"]]
    )
    g = nx.from_scipy_sparse_matrix(df_to_csr(ids1, ids2, n_rows=len(unique_ids), n_cols=len(unique_ids)))
    colored_nodes = np.asarray(list(nx.coloring.greedy_color(g, strategy=keep_order).items()))
    unique_ids['depth'] = colored_nodes[:, 1][colored_nodes[:, 0].argsort()]
    mentions = mentions.merge(unique_ids)
    print("done")
    
    print("Computing vocabularies...")
    [docs, sentences, mentions, tokens], vocs = normalize_vocabularies(
        [docs, sentences, mentions, tokens], 
        vocabularies={"split": ["train", "val", "test"]} if vocabularies is None else vocabularies,
        train_vocabularies={"source": False, "text": False} if vocabularies is None else False,
        verbose=True)
    print("done")
    
    prep = Dataset(docs=docs, sentences=sentences, mentions=mentions, tokens=tokens).copy()
    
    unique_mention_ids = encode_ids([mentions], ("doc_id", "sentence_id", "mention_id"))
    unique_sentence_ids = encode_ids([sentences, mentions, tokens], ("doc_id", "sentence_id"))
    unique_doc_ids = encode_ids([docs, sentences, mentions, tokens], ("doc_id",))
    
    batcher = Batcher({
        "mention": {
            "mention_id": mentions["mention_id"],
            "sentence_id": mentions["sentence_id"],
            "doc_id": mentions["doc_id"],
            "begin": mentions["begin"],
            "end": mentions["end"],
            "depth": mentions["depth"],
            "ner_label": mentions["ner_label"].cat.codes,
        },
        "sentence": {
            "sentence_id": sentences["sentence_id"],
            "doc_id": sentences["doc_id"],
            "mention_id": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], mentions["mention_id"], n_rows=len(unique_sentence_ids)),
            "mention_mask": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], n_rows=len(unique_sentence_ids)),
            "token": df_to_csr(tokens["sentence_id"], tokens["token_idx"], tokens["token"].cat.codes, n_rows=len(unique_sentence_ids)),
            "token_mask": df_to_csr(tokens["sentence_id"], tokens["token_idx"], n_rows=len(unique_sentence_ids)),
        },
        "doc": {
            "doc_id": np.arange(len(unique_doc_ids)),
            "sentence_id": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], sentences["sentence_id"], n_rows=len(unique_doc_ids)),
            "sentence_mask": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], n_rows=len(unique_doc_ids)),
            "split": docs["split"].cat.codes,
        }},
        masks={"sentence": {"token": "token_mask", "mention_id": "mention_mask"}, 
               "doc": {"sentence_id": "sentence_mask"}}
    )
    
    return batcher, prep, deltas, vocs

def keep_order(G, colors):
    """Returns a list of the nodes of ``G`` in ordered identically to their id in the graph
    ``G`` is a NetworkX graph. ``colors`` is ignored.
    This is to assign a depth using the nx.coloring.greedy_color function
    """
    return sorted(list(G))

In [10]:
bert_name = "bert-base-cased"
dataset = load_ncbi_disease()

docs = dataset['docs']

keep_n_first = None

if keep_n_first:
    docs = docs[:keep_n_first]
    first_ids = docs['doc_id']
    
    first_mentions = dataset["mentions"].loc[dataset["mentions"]['doc_id'].isin(first_ids)]
    first_fragments = dataset["fragments"].loc[dataset["fragments"]['doc_id'].isin(first_ids)]
    first_attributes = dataset["attributes"].loc[dataset["attributes"]['doc_id'].isin(first_ids)]
    
    dataset["mentions"] = first_mentions
    dataset["fragments"] = first_fragments
    dataset["attributes"] = first_attributes
    
dataset['mentions'] = dataset['mentions'].merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))

merged = dataset['mentions'].merge(dataset['labels'], on=['doc_id', 'mention_id'])
merged['mention_id'] = merged['label_id']
merged['label'] = merged['category']
    
dataset['docs'] = docs

dataset['mentions'] = merged

ner_labels = list(dataset['mentions']['label'].unique())

batcher, prep, deltas, vocs = preprocess_train(
    dataset=dataset,
    max_sentence_length=140,
    bert_name=bert_name,
    ner_labels= ner_labels,
    unknown_labels="drop",
    frag_merge=False,
)

Dataset: Dataset(
  (docs):       792 * ('doc_id', 'text', 'split')
  (mentions):  7059 * ('doc_id', 'mention_id', 'category', 'text', 'begin', 'end', 'label_id', 'label')
  (labels):    7059 * ('label_id', 'doc_id', 'mention_id', 'label')
  (fragments): 6881 * ('doc_id', 'mention_id', 'begin', 'end', 'fragment_id')
)
Transform texts... done
Splitting into sentences... done
Tokenizing... 



          id  token_id  token_idx    token  begin  end  sentence_idx    doc_id  split sentence_id
0          0         0          0    [CLS]      0    0             0  10192393  train         0/0
1          0         1          1        a      0    1             0  10192393  train         0/0
2          0         2          2   common      2    8             0  10192393  train         0/0
3          0         3          3    human      9   14             0  10192393  train         0/0
4          0         4          4     skin     15   19             0  10192393  train         0/0
...      ...       ...        ...      ...    ...  ...           ...       ...    ...         ...
255920  2437    255920        163    ##pop    640  643             1   8696339    dev       791/1
255921  2437    255921        164  ##tosis    643  648             1   8696339    dev       791/1
255922  2437    255922        165        .    648  649             1   8696339    dev       791/1
255923  2437    2559

INFO:nlstruct:Will train vocabulary for category
INFO:nlstruct:Will train vocabulary for ner_label
INFO:nlstruct:Will train vocabulary for token
INFO:nlstruct:Discovered existing vocabulary (28996 entities) for token
INFO:nlstruct:Normalized split, with given vocabulary and no unk
INFO:nlstruct:Normalized split, with given vocabulary and no unk
INFO:nlstruct:Normalized split, with given vocabulary and no unk


done
Processing nestings (overlapping areas)... done
Computing vocabularies...
done


In [11]:
print(batcher)

Batcher(
  [mention]:
    (mention_id): ndarray[int64](6687,)
    (sentence_id): ndarray[int64](6687,)
    (doc_id): ndarray[int64](6687,)
    (begin): ndarray[int64](6687,)
    (end): ndarray[int64](6687,)
    (depth): ndarray[int64](6687,)
    (ner_label): ndarray[int8](6687,)
  [sentence]:
    (sentence_id): ndarray[int64](2438,)
    (doc_id): ndarray[int64](2438,)
    (mention_id): csr_matrix[int64](2438, 20)
    (mention_mask): csr_matrix[bool](2438, 20)
    (token): csr_matrix[int16](2438, 233)
    (token_mask): csr_matrix[bool](2438, 233)
  [doc]:
    (doc_id): ndarray[int64](792,)
    (sentence_id): csr_matrix[int64](792, 6)
    (sentence_mask): csr_matrix[bool](792, 6)
    (split): ndarray[int8](792,)
)


In [11]:
from transformers import BertModel
bert, log = BertModel.from_pretrained(bert_name, output_loading_info=True)

In [12]:
#all_test_doc_ids = []
#sims = {}
#for i in range(200):
seed_all(1234567+137)

train_batcher = batcher['doc'][batcher['doc']['split']==0]['sentence']
test_batcher = batcher['doc'][batcher['doc']['split']==2]['sentence']

splits = np.zeros(len(train_batcher['doc']), dtype=int)

val_perc = 0.1
splits[np.random.choice(np.arange(len(splits)), size=int(val_perc * len(splits)))] = 1

val_batcher = test_batcher#batcher['sentence'][splits == 1]

# train_val_split = np.random.permutation(len(train_batcher))
# test_batcher = train_batcher[train_val_split[:int(0.1*len(train_val_split))]]['sentence']
# train_batcher = train_batcher[train_val_split[int(0.1*len(train_val_split)):]]['sentence']
sim = ((np.bincount(val_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(val_batcher['mention']) -
np.bincount(train_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(train_batcher['mention']))**2).sum()
print("Similarity (L2 dist) between train and val frequencies:", sim)
print("Frequencies")
#all_test_doc_ids.append((test_doc_ids, sim))
display(pd.DataFrame([
    {"index": "train", **dict(zip(vocs["ner_label"], np.bincount(train_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(train_batcher['mention'])))},
    {"index": "val", **dict(zip(vocs["ner_label"], np.bincount(val_batcher['mention', 'ner_label'], minlength=len(vocs["ner_label"]))/len(val_batcher['mention'])))},
]))

Similarity (L2 dist) between train and val frequencies: 0.0013780224108268146
Frequencies


Unnamed: 0,index,CompositeMention,DiseaseClass,Modifier,SpecificDisease
0,train,0.035844,0.155054,0.259766,0.549335
1,val,0.036442,0.126474,0.282958,0.554126


In [13]:
from sklearn.metrics import f1_score

def torch_f1(actions, target_tags, is_training=True):
    f1 = torch.from_numpy(np.array((f1_score(target_tags.reshape(-1).cpu(), actions.reshape(-1).cpu(), average="micro"))))
    f1.requires_grad = is_training
    return f1

class NERNet(torch.nn.Module):
    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 max_depth=10,
                 lstm_size=100,
                 ):
        super().__init__()
        
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        assert token_dim is not None, "Provide token_dim or embeddings"
        assert self.embeddings is not None

        dim = (token_dim if n_tokens > 0 else 0)
        
        self.ner_labels_subset = list(range(n_labels))

        self.dropout = torch.nn.Dropout(dropout)
        
#         if tag_scheme == "bio":
#             self.crf = BIODecoder(n_labels)
#         elif tag_scheme == "bioul":
#             self.crf = BIOULDecoder(n_labels)
        if tag_scheme == "bio":
            self.crf_list = torch.nn.ModuleList([BIODecoder(1, with_start_end_transitions=False) for _ in self.ner_labels_subset])
        elif tag_scheme == "bioul":
            self.crf_list = torch.nn.ModuleList([BIOULDecoder(1, with_start_end_transitions=False) for _ in self.ner_labels_subset])
        else:
            raise Exception()
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        self.batch_norm = torch.nn.BatchNorm1d(dim)

        n_tags = self.crf_list[0].num_tags
        
        self.metric_fc = torch.nn.Linear(dim, sum(crf.num_tags for crf in self.crf_list))
        
#         self.max_depth = max_depth
        
#         self.embed_size = dim
#         self.summary_size = n_tags
#         self.lstm_size = n_tags
        
#         self.tag_fc = torch.nn.Linear(dim, dim)
#         self.cell = torch.nn.LSTMCell(self.summary_size, self.lstm_size)
        
#         self.combined_fc = torch.nn.Linear(self.lstm_size + self.embed_size, n_tags)
        
    def forward(self, 
                tokens, 
                mask,
                tags=None,
                return_loss=False,):
        # Embed the tokens
        scores = None
        # shape: n_batch * sequence * 768
        embeds = self.embeddings(tokens)[0]
        
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(state)#self.linear(self.dropout(state)))# + state
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        
        scores = self.metric_fc(state)
        
#         summary_pred = scores.mean(1)
        
#         lstm_hidden = torch.zeros(summary_pred.shape[0], self.lstm_size).to(summary_pred.device)
#         lstm_mem = torch.zeros(summary_pred.shape[0], self.lstm_size).to(summary_pred.device)
        
#         sampled_spans = []
        
#         all_final_scores = []
#         for i_depth in range(self.max_depth):
            
#             lstm_hidden, lstm_mem = self.cell(summary_pred, (lstm_hidden, lstm_mem))
            
#             final_scores = torch.einsum("ijk,ik->ijk", scores,lstm_hidden)
            
#             tags = self.crf.decode(final_scores, mask)
#             sampled_spans.append(self.crf.tags_to_spans(tags, mask))
            
#             all_final_scores.append(final_scores)
            
#         if len(sampled_spans)>0:
#             sampled_spans = {
#                 k: torch.cat([gm[k] for gm in sampled_spans], -1) for k in sampled_spans[0].keys() 
#             }
            
#         all_final_scores = torch.mean(torch.stack(all_final_scores), 0)
        
        scores = scores.reshape((*scores.shape[:-1], -1, len(self.crf_list))).permute(3, 0, 1, 2)
        
        sampled_spans = []
        sampled_tags = []
        baseline_tags = []
            
        for i, (ner_label_idx, crf) in enumerate(zip(self.ner_labels_subset, self.crf_list)):
            sample_tags = crf.sample(scores[i], mask, n=1)
            argmax_tags = crf.decode(scores[i], mask)
            extracted = crf.tags_to_spans(sample_tags[0], mask)
            
            extracted['span_label'] = torch.full_like(extracted["span_label"], ner_label_idx)
        
            sampled_spans.append(extracted)
            sampled_tags.append(sample_tags)
            baseline_tags.append(argmax_tags)
        
        if return_loss:
            loss = -torch.stack([crf(scores[i], mask, tags[0][..., ner_label_idx], reduction="none") 
                                for i, (ner_label_idx, crf) in enumerate(zip(self.ner_labels_subset, self.crf_list))]).mean()
            
        if len(sampled_spans)>0:
            sampled_spans = {
                k: torch.cat([sm[k] for sm in sampled_spans], -1) for k in sampled_spans[0].keys() 
            }
    
        return {
            "scores": scores,
            "sampled_spans": sampled_spans,
            "loss": loss if return_loss else None,
            "sampled_tags": sampled_tags,
            "baseline_tags": sampled_tags,
        }
    
    def init_lstm_state(self, seq_len):
        return (torch.autograd.Variable(torch.zeros(1, seq_len, self.lstm_dim)),
                torch.autograd.Variable(torch.zeros(1, seq_len, self.lstm_dim)))

In [14]:
def select_mention_level(batch_mentions, depth_level):
    tag_mask = batch_mentions["mention", "depth"]==depth_level
    return batch_mentions["mention"][tag_mask]

import traceback
from tqdm import tqdm

from transformers import AdamW, BertModel

from tqdm import tqdm
from scipy.sparse import csr_matrix

from nlstruct.environment import get_cache
from nlstruct.utils import evaluating, torch_global as tg, freeze
from nlstruct.scoring import compute_metrics, merge_pred_and_gold
from nlstruct.train import make_optimizer_and_schedules, iter_optimization, seed_all
from nlstruct.train.schedule import ScaleOnPlateauSchedule, LinearSchedule, ConstantSchedule

from nlstruct.utils import torch_clone
    
device = torch.device('cuda:0')
tg.set_device(device)
all_preds = []
histories = []

# To release gpu memory before allocating new parameters for a new model
# A better idea would be to run xp in a function, so that all variables are released when exiting the fn
# but this way we can debug after this cell if something goes wrong
if "all_nets" in globals(): del all_nets
if "optim" in globals(): del optim, 
if "schedules" in globals(): del schedules
if "final_schedule" in globals(): del final_schedule
if "state" in globals(): del state
    
# Hyperparameter search
layer = 2
hidden_dim = 1024 
scheme = 'bioul' 
seed = 12
lr = 7e-4
bert_lr = 6e-5
dropout = 0.1

seed_all(seed) # /!\ Super important to enable reproducibility

max_grad_norm = 5.
bert_weight_decay = 0.0000
batch_size = 64
random_perm=True
n_freeze = layer + 2
bert_dropout = 0.2

ner_net = NERNet(
        n_tokens=len(vocs["token"]),
        token_dim=1024 if "large" in bert_name else 768,#768,
        embeddings=BertModel.from_pretrained(bert_name),#, custom_embeds_layer_index=custom_embeds_layer_index),
        n_labels = len(vocs['ner_label']),
        dropout=dropout,
        hidden_dim=hidden_dim,
        tag_scheme=scheme,
)
all_nets = torch.nn.ModuleDict({
    "ner_net": ner_net,
}).to(device=tg.device)
del ner_net

for module in all_nets["ner_net"].embeddings.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = bert_dropout
all_nets.train()

# Define the optimizer, maybe multiple learning rate / schedules per parameters groups
optim = AdamW(params=all_nets['ner_net'].parameters(), lr=lr)

# Freeze some bert layers 
# - n_freeze = 0 to freeze nothing
# - n_freeze = 1 to freeze word embeddings / position embeddings / ...
# - n_freeze = 2..13 to freeze the first, second ... 12th layer of bert
for name, param in all_nets.named_parameters():
    match = re.search("\.(\d+)\.", name)
    if (match and int(match.group(1)) < n_freeze - 1) and ('embeddings' in name):
        freeze([param])
        
# if n_freeze > 0:
#     if hasattr(all_nets['ner_net'].embeddings, 'embeddings'):
#         freeze(all_nets['ner_net'].embeddings.embeddings)
#     else:
#         freeze(all_nets['ner_net'].embeddings)

with_tqdm = True
state = {"all_nets": all_nets, "optim": optim}  # all we need to restart the training from a given epoch

cache = get_cache("genia_rl_tests", {
    "seed": seed, 
    "train_batcher": train_batcher, 
    "val_batcher": None, 
    "random_perm": random_perm,
    "batch_size": batch_size, 
    "max_grad_norm": max_grad_norm, 
    **state,
}, loader=torch.load, dumper=torch.save)

!rm -rf $cache

cache = get_cache("genia_rl_tests", {
    "seed": seed, 
    "train_batcher": train_batcher, 
    "val_batcher": None, 
    "random_perm": random_perm,
    "batch_size": batch_size, 
    "max_grad_norm": max_grad_norm, 
    **state
}, loader=torch.load, dumper=torch.save)

#batch = next(iter(train_batcher['sentence'].dataloader(batch_size=batch_size, shuffle=True, sparse_sort_on=["token_mask"], device=device)))

level = 0

level_train_batcher = train_batcher['mention'].set_join_order(('mention',))[train_batcher['mention','depth']==level]
level_val_batcher = val_batcher['mention'].set_join_order(('mention',))[val_batcher['mention','depth']==level]

for epoch_before, state, history, record in iter_optimization(
    main_score = "val_f1", # do not earlystop based on validation
    metrics_info=metrics_info,
    max_epoch=50,
    patience=50,
    state=state, 
    cache_policy="all", 
    #cache=cache,
    n_save_checkpoints=2,
#             exit_on_score=0.92,
):
    pred_batches = []
    gold_batches = []

    total_train_ner_loss = 0
    total_train_acc = 0
    total_train_ner_size = 0

    n_mentions = len(train_batcher["mention"])
    n_matched_mentions = 0
    n_target_mentions = 0
    n_observed_mentions = 0

    #with tqdm(repeat(batch, 1000), disable=not with_tqdm) as bar:
    with tqdm(train_batcher['sentence'].dataloader(batch_size=batch_size, shuffle=True, sparse_sort_on=["token_mask"], device=device), disable=not with_tqdm) as bar:
        for batch_i, batch in enumerate(bar):
            
            #seed_all(seed)
            
            optim.zero_grad()

            n_samples, sentence_size = batch["sentence", "token"].shape
            n_labels = len(vocs["ner_label"])
            n_depth = batch["mention", "depth"].max() + 1
            
            mask = batch["token_mask"]

            # Compute the tokens label tag of the selected non-overlapping gold mentions to infer from the model
            #target_tags = {}

            #for depth_level in range(3):

            #bm = select_mention_level(batch, depth_level)

            target_tags = BIOULDecoder.spans_to_tags(
                batch["mention", "depth"] * n_labels * n_samples + (batch["mention", "@sentence_id"] * n_labels + batch["mention", "ner_label"]),
                batch["mention", "begin"],
                batch["mention", "end"], 
                torch.zeros_like(batch["mention", "ner_label"]),
                n_tokens=sentence_size,
                n_samples=n_samples * n_labels * n_depth,
            )
            target_tags = target_tags.view(n_depth, n_samples, n_labels, sentence_size)
            target_tags = target_tags.transpose(-1, -2)
            
            #if depth_level==0:
            #    gold_batches.append(bm)

            # WITH DEPTH
            
#             # Compute the tokens label tag of the selected non-overlapping gold mentions to infer from the model
#             n_depth = batch["mention", "depth"].max().item() + 2
#             target_tags = all_nets["ner_net"].crf_list1[0].spans_to_tags(
#                 batch["mention", "@sentence_id"] + batch["mention", "depth"] * n_samples,
#                 batch["mention", "begin"],
#                 batch["mention", "end"], 
#                 batch["mention", "ner_label"], 
#                 n_tokens=sentence_size,
#                 n_samples=n_samples * n_depth,
#             ) # [n_samples * n_labels] * sentence_size
#             target_tags = target_tags.view(n_depth, n_samples, sentence_size)

            res = all_nets["ner_net"].forward(
                tokens=batch["token"],
                mask=mask,
                tags=target_tags,
                return_loss = True,
            )
            
            ner_loss = res['loss']
            scores = res['scores']
            spans = res['sampled_spans']
            sampled_tags = res['sampled_tags']
            baseline_tags = res['baseline_tags']
            
            pred_batch = Batcher({
                "mention": {
                    "mention_id": torch.arange(n_mentions, n_mentions+len(spans['span_doc_id']), device=device),
                    "begin": spans["span_begin"],
                    "end": spans["span_end"],
                    "ner_label": spans["span_label"],
                    "@sentence_id": spans["span_doc_id"],
                },
                "sentence": dict(batch["sentence", ["sentence_id", "doc_id"]]),
                "doc": dict(batch["doc"])}, 
                check=False)
            pred_batches.append(pred_batch)
            n_mentions += len(spans['span_doc_id'])
            
            sampled_f1 = torch_f1(target_tags[0], torch.stack(sampled_tags).squeeze().permute(1,2,0))
            baseline_f1 = torch_f1(target_tags[0], torch.stack(baseline_tags).squeeze().permute(1,2,0))
            
            total_train_ner_loss += float(ner_loss) * len(batch["sentence"])
            total_train_ner_size += len(batch["sentence"])
            
            loss = ner_loss * ((sampled_f1 - baseline_f1) or 1)
            
            if (sampled_f1 - baseline_f1 )!= 0:
                print("AY")
            
            # Perform optimization step
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_nets.parameters(), max_grad_norm)
            optim.step()
            
            bar.set_postfix(loss=loss.item())
            
    from nlstruct.scoring import compute_metrics, merge_pred_and_gold
    
    train_pred = Batcher.concat(pred_batches)
    val_pred = extract_mentions(val_batcher, all_nets=all_nets)
    
    train_metrics = compute_metrics(merge_pred_and_gold(
        pred=pd.DataFrame(dict(train_pred["mention", ["sentence_id", "begin", "end", "ner_label"]])), #extract_mentions(val_batcher, all_nets=all_nets)
        gold=pd.DataFrame(dict(level_train_batcher["mention", ["sentence_id", "begin", "end", "ner_label"]])), 
        span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["sentence_id", ("begin", "end"), "ner_label"]), prefix='train_')[["train_recall", "train_precision", "train_f1"]].to_dict()

    val_metrics = compute_metrics(merge_pred_and_gold(
        pred=pd.DataFrame(dict(val_pred["mention", ["sentence_id", "begin", "end", "ner_label"]])), #extract_mentions(val_batcher, all_nets=all_nets)
        gold=pd.DataFrame(dict(level_val_batcher["mention", ["sentence_id", "begin", "end", "ner_label"]])), 
        span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["sentence_id", ("begin", "end"), "ner_label"]), prefix='val_')[["val_recall", "val_precision", "val_f1"]].to_dict()
    
    # Compute precision, recall and f1 on train set
#     train_pred = Batcher.concat(pred_batches)
#     train_gold = Batcher.concat(gold_batches)
#     train_metrics    = compute_scores(train_pred, train_gold, prefix='train_')
    
#     val_pred = extract_mentions(val_batcher, all_nets=all_nets)
#     val_gold = select_mention_level(val_batcher, 0)
#     val_metrics     = compute_scores(val_pred, val_gold, prefix='val_')
    
#     test_pred = extract_mentions(test_batcher, all_nets=all_nets)
#     test_gold = select_mention_level(test_batcher, 0)
#     test_metrics = compute_scores(test_pred, test_gold, prefix='test_')
    
    record(
    {
        "train_ner_loss": total_train_ner_loss / max(total_train_ner_size, 1),
        **train_metrics,
        **val_metrics,
        "n_matched": n_matched_mentions,
    })

INFO:nlstruct:Available CUDA devices: 1
INFO:nlstruct:Current device: cuda:0
100%|██████████| 29/29 [00:24<00:00,  1.16it/s, loss=61.3]
INFO:nlstruct:epoch | train_ner_loss | train_f1 | [31mval_f1[0m | n_matched |    dur(s)
INFO:nlstruct:    1 |       [32m154.5636[0m |   [32m0.0032[0m | [32m0.0071[0m |    0.0000 |   26.8717
100%|██████████| 29/29 [00:24<00:00,  1.16it/s, loss=15.8]
INFO:nlstruct:    2 |        [32m92.5503[0m |   [32m0.0042[0m | [32m0.0136[0m |    0.0000 |   26.7806
100%|██████████| 29/29 [00:25<00:00,  1.16it/s, loss=4.47]
INFO:nlstruct:    3 |        [32m47.3702[0m |   [32m0.0044[0m | [32m0.0239[0m |    0.0000 |   26.9177
100%|██████████| 29/29 [00:25<00:00,  1.16it/s, loss=6.21]
INFO:nlstruct:    4 |        [32m24.2289[0m |   [32m0.0122[0m | [32m0.0640[0m |    0.0000 |   26.7759
100%|██████████| 29/29 [00:24<00:00,  1.16it/s, loss=2.83]
INFO:nlstruct:    5 |        [32m16.3060[0m |   [32m0.0415[0m | [32m0.1853[0m |    0.0000 |   26.748

# Test 
(to be fair, avoid executing this part of the notebook to often, or use the training set instead)