In [1]:
# Reload edited project modules automatically before each cell is executed
# (IPython `autoreload` extension, mode 2 = all modules).
%load_ext autoreload
%autoreload 2

In [2]:
from nlstruct.dataloaders.medic import get_raw_medic

In [3]:
from __future__ import absolute_import
import argparse
import numpy as np
import torch
import os
import sys
import logging
import pdb

sys.path.insert(0,'/home/ytaille/AttentionSegmentation')

from allennlp.data import Vocabulary
from allennlp.data.iterators import DataIterator
# import allennlp.data.dataset_readers as Readers
import AttentionSegmentation.reader as Readers

# import model as Models
import AttentionSegmentation.model.classifiers as Models

from AttentionSegmentation.commons.utils import \
    setup_output_dir, read_from_config_file
from AttentionSegmentation.commons.model_utils import \
    construct_vocab, load_model_from_existing
# from AttentionSegmentation.visualization.visualize_attns import \
#     html_visualizer
import AttentionSegmentation.model.attn2labels as SegmentationModels

"""The main entry point

This is the main entry point for training HAN SOLO models.

Usage::

    ${PYTHONPATH} -m AttentionSegmentation/main
        --config_file ${CONFIG_FILE}

"""
# Stand-in for the command-line arguments of the original script entry point.
args = type('MyClass', (object,), {'content':{}})()
args.config_file = 'Configs/config_ncbi.json'
args.log = 'INFO'
args.loglevel = 'INFO'
args.seed = 1

# Setup Experiment Directory
config = read_from_config_file(args.config_file)
if args.seed > 0:
    # Seed every RNG used below so that runs are repeatable.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    trainer_config = config.get('trainer', None)
    # CUDA device indices start at 0 and -1 is the "no GPU" default used here,
    # so the check has to be ">= 0": with "> 0" a run on device 0 never reached
    # the explicit CUDA seeding below.
    if trainer_config is not None and \
       trainer_config.get('cuda_device', -1) >= 0:
        torch.cuda.manual_seed(args.seed)
# `setup_output_dir` returns the run directory together with the config.
serial_dir, config = setup_output_dir(config, args.loglevel)
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)

# ---- Training split ---------------------------------------------------------
TRAIN_PATH = config.pop("train_data_path")
logger.info("Loading Training Data from {0}".format(TRAIN_PATH))

# The reader class is resolved by name on the project's reader package.
dataset_reader_params = config.pop("dataset_reader")
reader_type = dataset_reader_params.pop("type", None)
assert reader_type is not None and hasattr(Readers, reader_type),\
    f"Cannot find reader {reader_type}"
reader_class = getattr(Readers, reader_type)
reader = reader_class.from_params(dataset_reader_params)

instances_train = reader.read(file_path=TRAIN_PATH)
logger.info("Length of {0}: {1}".format("Training Data", len(instances_train)))

# ---- Validation split -------------------------------------------------------
VAL_PATH = config.pop("validation_data_path")
logger.info("Loading Validation Data from {0}".format(VAL_PATH))
instances_val = reader.read(VAL_PATH)
logger.info("Length of {0}: {1}".format("Validation Data", len(instances_val)))

# ---- Test split (optional: only read when the config provides a path) -------
TEST_PATH = config.pop("test_data_path", None)
if TEST_PATH is None:
    instances_test = None
else:
    logger.info("Loading Test Data from {0}".format(TEST_PATH))
    instances_test = reader.read(TEST_PATH)
    logger.info("Length of {0}: {1}".format("Testing Data", len(instances_test)))

# # Load Pretrained Existing Model
# load_config = config.pop("load_from", None)

# ---- Vocabulary -------------------------------------------------------------
# A configured size of -1 is logged as-is, then mapped to None before being
# handed to `Vocabulary.from_instances`.
vocab_size = config.pop("max_vocab_size", -1)
logger.info("Constructing Vocab of size: {0}".format(vocab_size))
if vocab_size == -1:
    vocab_size = None
vocab = Vocabulary.from_instances(instances_train, max_vocab_size=vocab_size)

# The "vocab" sub-directory of the run directory must already exist.
vocab_dir = os.path.join(serial_dir, "vocab")
assert os.path.exists(vocab_dir), "Couldn't find the vocab directory"
vocab.save_to_files(vocab_dir)

# if load_config is not None:
#     # modify the vocab from the source model vocab
#     src_vocab_path = load_config.pop("vocab_path", None)
#     if src_vocab_path is not None:
#         vocab = construct_vocab(src_vocab_path, vocab_dir)
#         # Delete the old vocab
#         for file in os.listdir(vocab_dir):
#             os.remove(os.path.join(vocab_dir, file))
#         # save the new vocab
#         vocab.save_to_files(vocab_dir)
logger.info("Saving vocab to {0}".format(vocab_dir))
logger.info("Vocab Construction Done")

# Construct the data iterator from the "iterator" config section and give it
# the vocabulary built from the training instances.
logger.info("Constructing Data Iterators")
data_iterator = DataIterator.from_params(config.pop("iterator"))
data_iterator.index_with(vocab)

logger.info("Data Iterators Done")

# Create the model
logger.info("Constructing The model")
model_params = config.pop("model")
model_type = model_params.pop("type")
# The model class is resolved by name on the project's classifier package.
# The failure message used to say "reader" (copy-paste slip from the reader
# lookup), which pointed at the wrong config section.
assert model_type is not None and hasattr(Models, model_type),\
    f"Cannot find model {model_type}"
model = getattr(Models, model_type).from_params(
    vocab=vocab,
    params=model_params,
    label_indexer=reader.get_label_indexer()
)
logger.info("Model Construction done")

# visualize = config.pop("visualize", False)
# visualizer = None
# if visualize:
#     visualizer = html_visualizer(vocab, reader)
# Build the segmenter from the "segmentation" config section. Its class is
# resolved by name on `AttentionSegmentation.model.attn2labels`; the lookup is
# checked the same way as the reader and model lookups so that a bad "type"
# fails with a message naming the offending value.
segmenter_params = config.pop("segmentation")
segment_class = segmenter_params.pop("type")
assert segment_class is not None and hasattr(SegmentationModels, segment_class),\
    f"Cannot find segmenter {segment_class}"
segmenter = getattr(SegmentationModels, segment_class).from_params(
    vocab=vocab,
    reader=reader,
    params=segmenter_params
)

# logger.info("Segmenter Done")

# print("##################################\nAYYYYYYYYYYYYYYYYYYYYYYYY\n\n\n\n\n\n\n\n###########################")

# exit()


# if load_config is not None:
#     # Load the weights, as specified by the load_config
#     model_path = load_config.pop("model_path", None)
#     layers = load_config.pop("layers", None)
#     load_config.assert_empty("Load Config")
#     assert model_path is not None,\
#         "You need to specify model path to load from"
#     model = load_model_from_existing(model_path, model, layers)
#     logger.info("Pretrained weights loaded")

# logger.info("Starting the training process")





2021-03-08 11:07:21,017: INFO: train_data_path = /home/ytaille/data/resources/medic/ncbi_conll_ner_train.conll
2021-03-08 11:07:21,019: INFO: Loading Training Data from /home/ytaille/data/resources/medic/ncbi_conll_ner_train.conll
2021-03-08 11:07:21,022: INFO: dataset_reader.type = WeakConll2003DatasetReader
2021-03-08 11:07:21,026: INFO: dataset_reader.token_indexers.bert.type = bert-pretrained
2021-03-08 11:07:21,028: INFO: dataset_reader.token_indexers.bert.pretrained_model = ./Data/embeddings/bert-base-multilingual-cased-vocab.txt
2021-03-08 11:07:21,029: INFO: dataset_reader.token_indexers.bert.use_starting_offsets = True
2021-03-08 11:07:21,030: INFO: dataset_reader.token_indexers.bert.do_lowercase = False
2021-03-08 11:07:21,031: INFO: dataset_reader.token_indexers.bert.never_lowercase = None
2021-03-08 11:07:21,031: INFO: dataset_reader.token_indexers.bert.max_pieces = 512
2021-03-08 11:07:21,034: INFO: loading vocabulary file ./Data/embeddings/bert-base-multilingual-cased-voc

0it [00:00, ?it/s]

2021-03-08 11:07:21,187: INFO: Reading instances from lines in file at: /home/ytaille/data/resources/medic/ncbi_conll_ner_train.conll


1803it [00:01, 1717.00it/s]

2021-03-08 11:07:22,237: INFO: Length of Training Data: 1803
2021-03-08 11:07:22,239: INFO: validation_data_path = /home/ytaille/data/resources/medic/ncbi_conll_ner_dev.conll
2021-03-08 11:07:22,240: INFO: Loading Validation Data from /home/ytaille/data/resources/medic/ncbi_conll_ner_dev.conll



0it [00:00, ?it/s]

2021-03-08 11:07:22,244: INFO: Reading instances from lines in file at: /home/ytaille/data/resources/medic/ncbi_conll_ner_dev.conll


319it [00:00, 1997.82it/s]

2021-03-08 11:07:22,405: INFO: Length of Validation Data: 319
2021-03-08 11:07:22,406: INFO: test_data_path = /home/ytaille/data/resources/medic/ncbi_conll_ner_test.conll
2021-03-08 11:07:22,407: INFO: Loading Test Data from /home/ytaille/data/resources/medic/ncbi_conll_ner_test.conll



0it [00:00, ?it/s]

2021-03-08 11:07:22,411: INFO: Reading instances from lines in file at: /home/ytaille/data/resources/medic/ncbi_conll_ner_test.conll


316it [00:00, 1198.36it/s]

2021-03-08 11:07:22,675: INFO: Length of Testing Data: 316
2021-03-08 11:07:22,677: INFO: max_vocab_size = -1
2021-03-08 11:07:22,678: INFO: Constructing Vocab of size: -1
2021-03-08 11:07:22,680: INFO: Fitting token dictionary from dataset.



100%|██████████| 1803/1803 [00:00<00:00, 37312.78it/s]


2021-03-08 11:07:22,742: INFO: Saving vocab to ./trained_models/NCBI-BERT-realFT-PS/run-180/vocab
2021-03-08 11:07:22,743: INFO: Vocab Construction Done
2021-03-08 11:07:22,745: INFO: Constructing Data Iterators
2021-03-08 11:07:22,746: INFO: iterator.type = bucket
2021-03-08 11:07:22,747: INFO: iterator.sorting_keys = [['tokens', 'bert']]
2021-03-08 11:07:22,749: INFO: iterator.padding_noise = 0.1
2021-03-08 11:07:22,750: INFO: iterator.biggest_batch_first = False
2021-03-08 11:07:22,751: INFO: iterator.batch_size = 32
2021-03-08 11:07:22,751: INFO: iterator.instances_per_epoch = None
2021-03-08 11:07:22,752: INFO: iterator.max_instances_in_memory = None
2021-03-08 11:07:22,753: INFO: Data Iterators Done
2021-03-08 11:07:22,753: INFO: Constructing The model
2021-03-08 11:07:22,755: INFO: model.type = MultiClassifier
2021-03-08 11:07:22,756: INFO: model.method = binary
2021-03-08 11:07:22,757: INFO: model.text_field_embedder.type = basic
2021-03-08 11:07:22,758: INFO: model.text_field_

  "num_layers={}".format(dropout, num_layers))


In [4]:
# The unknown tag has to be added to the "chunk_tags" vocabulary namespace
# up front to avoid errors later on.
data_iterator.vocab.add_token_to_namespace("@@UNKNOWN@@", "chunk_tags")

1141

In [5]:
# Re-read the configuration from disk: the setup cell consumed several of its
# sections with `config.pop(...)`.
# NOTE(review): this replaces the `config` returned by `setup_output_dir` —
# confirm that dropping that version is intended.
config = read_from_config_file(args.config_file)


In [6]:
from AttentionSegmentation.trainer import Trainer

from nlstruct.utils import  torch_global as tg

# Build the project trainer from the "trainer" config section.
# `config.pop` removes that section from `config`, so re-running this cell
# requires re-reading the configuration first.
trainer = Trainer.from_params(
    model=model,
    base_dir=serial_dir,
    iterator=data_iterator,
    train_data=instances_train,
    validation_data=instances_val,
    segmenter=segmenter,
    params=config.pop("trainer")
)


2021-03-08 11:07:36,816: INFO: PyTorch version 1.5.1 available.
2021-03-08 11:07:40,103: INFO: TensorFlow version 2.3.1 available.
2021-03-08 11:07:40,598: INFO: Loading faiss with AVX2 support.
2021-03-08 11:07:40,602: INFO: Loading faiss.
2021-03-08 11:07:41,051: INFO: trainer.patience = 10
2021-03-08 11:07:41,053: INFO: trainer.validation_metric = +accuracy
2021-03-08 11:07:41,054: INFO: trainer.num_epochs = 50
2021-03-08 11:07:41,055: INFO: trainer.cuda_device = 0
2021-03-08 11:07:41,056: INFO: trainer.grad_norm = None
2021-03-08 11:07:41,057: INFO: trainer.grad_clipping = None
2021-03-08 11:07:41,059: INFO: trainer.num_serialized_models_to_keep = 1
2021-03-08 11:07:43,769: INFO: trainer.optimizer.type = adam
2021-03-08 11:07:43,771: INFO: trainer.optimizer.parameter_groups = [[['.*bert.*'], ConfigTree([('lr', 2e-07)])], [['.*encoder_word.*', '.*attn.*', '.*logit.*'], ConfigTree([('lr', 0.001)])]]
2021-03-08 11:07:43,773: INFO: Converting Params object to dict; logging of default v

In [7]:
# BIT FOR BOOSTING SURROUNDING ATTENTIONS

# attn = torch.Tensor([[0,1,0,1,0], [0,0,1,0,0]])

# attn_boosted = attn.clone()
# nnz = (attn>0).nonzero().t().chunk(chunks=2,dim=0)

# print(nnz)

# new_nnz = [[], []]

# for nz0, nz1 in zip(nnz[0][0].numpy(), nnz[1][0].numpy()):
#     new_nnz[0].extend([nz0,nz0])
#     new_nnz[1].extend([nz1-1,nz1+1])
    
# new_nnz[0] = torch.Tensor([new_nnz[0]]).long()
# new_nnz[1] = torch.Tensor([new_nnz[1]]).long()
# new_nnz = (new_nnz[0], new_nnz[1])
# print(new_nnz)
# attn_boosted[new_nnz] += 0.1

# attn_boosted

In [8]:

# USE BIO BERT
# TRAIN STEP 1 ONLY ON MEDIC LABELS (+ NCBI MENTIONS)
# PREPROCESS / TRAIN / REACH GOOD SCORES
# GET MEDIC ALTERNATIVE LABELS IN NLSTRUCT -> TRANSLATE NCBI LABELS TO MEDIC

# USE ENTROPY INSTEAD OF CROSS ENTROPY -> not rely on labelled data only (rely on model certainty)

# GROUPS: semantic type at the mention level (do not use)

# NGRAMS FOR ENTITIES -> not possible with discontinued entities

# Use "separation token" in phrases ?

# Use a limited number of attention heads (not one per class)

# Use same method as Perceval for trajectories (draw closest ones, reduce list, repeat) -> prédiction itérative

# Maybe remove weakly supervised completely?

# Test with Reinforce only after a few epochs

# Representation factor to weight Perceval's loss?

# Several factors to make up the reward

# Similarity factor between extracted mention and synonym, rather than between mention and label?

# Make sure that every trajectory is different -> draw first then use Perceval

# Final metric: do we manage to retrieve the CUIs? -> because entity boundaries are hard to determine

# Use only one class ? -> simpler because all mentions are diseases -> MAKE SURE THAT SEVERAL MENTIONS ARE PREDICTABLE

# maybe problem with reinforce algo comes from hyperparameters?? -> USE OPTIMIZER PARAMETER SPECIFICATION

# change objective: instead of WL use RL metrics -> measure on CUI



In [9]:
# USE PERCEVAL'S WAY OF PREDICTING:

# (Sample from the attention?)

# Input: token embeddings + label embeddings

# -> the attention mask is used to decide which tokens to feed
# For a predicted CUI: get Perceval's loss, compare with the closest CUI?
# DEFT helps the attentions, but do we drop the attention mask?


In [10]:
class NERNet(torch.nn.Module):
    """Per-token tag scorer built on top of a token embedding module.

    The forward pass embeds the tokens, zeroes the masked-out positions, applies
    dropout -> linear -> ReLU -> batch norm, and projects every position onto
    the tags of the decoder selected by ``tag_scheme``.

    Parameters
    ----------
    n_labels:
        Forwarded to the tag decoder constructor.
    hidden_dim:
        Output width of the hidden linear layer; ``None`` reuses the embedding
        width.
    dropout:
        Dropout probability applied to the embedded tokens.
    n_tokens, token_dim:
        Size of the embedding table. When ``embeddings`` is given and either
        value is missing, both are read from the embedding weights.
    embeddings:
        Optional pre-built embedding module. When omitted, a fresh
        ``torch.nn.Embedding(n_tokens, token_dim)`` is created.
    tag_scheme:
        ``"bio"`` or ``"bioul"``; selects the decoder stored in ``self.crf``.
    metric:
        ``"linear"``, ``"cosine"`` or ``"ema_cosine"``; selects the scoring head.
    metric_fc_kwargs:
        Extra keyword arguments for the cosine scoring heads.
    rescale:
        Optional ``rescale`` argument for the cosine scoring heads. It is only
        forwarded when not ``None`` and takes precedence over a ``"rescale"``
        entry in ``metric_fc_kwargs``. (The constructor used to read an
        undefined ``rescale`` name in those branches.)

    Raises
    ------
    ValueError
        If no embedding can be built, or if ``tag_scheme`` / ``metric`` is not
        one of the supported values.
    """

    def __init__(self,
                 n_labels,
                 hidden_dim,
                 dropout,
                 n_tokens=None,
                 token_dim=None,
                 embeddings=None,
                 tag_scheme="bio",
                 metric='linear',
                 metric_fc_kwargs=None,
                 rescale=None,
                 ):
        super().__init__()
        if embeddings is not None:
            self.embeddings = embeddings
            if n_tokens is None or token_dim is None:
                # Infer the table size from the module itself; wrappers expose
                # the table one level down, under `.embeddings`.
                if hasattr(embeddings, 'weight'):
                    n_tokens, token_dim = embeddings.weight.shape
                else:
                    n_tokens, token_dim = embeddings.embeddings.weight.shape
        else:
            # Fail with an explicit message instead of a TypeError on
            # `None > 0` or a late assertion on a missing embedding module.
            if n_tokens is None or token_dim is None or n_tokens <= 0:
                raise ValueError(
                    "Provide `embeddings`, or both a positive `n_tokens` and a `token_dim`")
            self.embeddings = torch.nn.Embedding(n_tokens, token_dim)

        dim = (token_dim if n_tokens > 0 else 0)
        self.dropout = torch.nn.Dropout(dropout)
        if tag_scheme == "bio":
            self.crf = BIODecoder(n_labels)
        elif tag_scheme == "bioul":
            self.crf = BIOULDecoder(n_labels)
        else:
            raise ValueError(
                "Unknown tag_scheme {!r}, expected 'bio' or 'bioul'".format(tag_scheme))
        if hidden_dim is None:
            hidden_dim = dim
        self.linear = torch.nn.Linear(dim, hidden_dim)
        # Everything after `self.linear` sees `hidden_dim` features, so the
        # batch norm and the scoring head must be sized with `hidden_dim`
        # (sizing them with `dim` only worked when both widths were equal).
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)

        n_tags = self.crf.num_tags
        # Copy so that the caller's dict is never modified.
        metric_fc_kwargs = dict(metric_fc_kwargs) if metric_fc_kwargs is not None else {}
        if rescale is not None:
            metric_fc_kwargs["rescale"] = rescale
        if metric == "linear":
            self.metric_fc = torch.nn.Linear(hidden_dim, n_tags)
        elif metric == "cosine":
            self.metric_fc = CosineSimilarity(hidden_dim, n_tags, **metric_fc_kwargs)
        elif metric == "ema_cosine":
            self.metric_fc = EMACosineSimilarity(hidden_dim, n_tags, **metric_fc_kwargs)
        else:
            raise ValueError(
                "Unknown metric {!r}, expected 'linear', 'cosine' or 'ema_cosine'".format(metric))

    def extended_embeddings(self, tokens, mask, **kwargs):
        """Embed ``tokens``, passing ``mask`` and ``kwargs`` only to transformer-like modules.

        A module exposing an ``encoder`` or ``transformer`` attribute is called
        as ``embeddings(tokens, mask, **kwargs)`` and the first element of its
        output is returned; any other module is called as ``embeddings(tokens)``.
        """
        # Default case here, size <= 512
        # Small ugly check to see if self.embeddings is Bert-like, then we need to pass a mask
        if hasattr(self.embeddings, 'encoder') or hasattr(self.embeddings, 'transformer'):
            return self.embeddings(tokens, mask, **kwargs)[0]
        else:
            return self.embeddings(tokens)

    def forward(self, tokens, mask, tag_embeds=None, return_embeddings=False):
        """Score every token position against every tag.

        Returns a dict with ``"scores"`` (output of the scoring head) and
        ``"embeddings"`` (the raw embeddings when ``return_embeddings`` is
        true, otherwise ``None``).
        """
        # Embed the tokens
        # Original note: shape is n_batch * sequence * 768 — TODO confirm for non-BERT embeddings.
        embeds = self.extended_embeddings(tokens, mask, custom_embeds=tag_embeds)
        # Zero the positions that are masked out (`~mask` needs a boolean mask).
        state = embeds.masked_fill(~mask.unsqueeze(-1), 0)
        state = torch.relu(self.linear(self.dropout(state)))# + state
        # BatchNorm1d works on (N, features): flatten all leading dimensions,
        # normalise, then restore the original shape.
        state = self.batch_norm(state.view(-1, state.shape[-1])).view(state.shape)
        scores = self.metric_fc(state)
        return {
            "scores": scores,
            "embeddings": embeds if return_embeddings else None,
        }

In [11]:
from __future__ import absolute_import
import logging
import os
import shutil
import json
from collections import deque
import time
import re
import datetime
import traceback
import numpy as np
from typing import Dict, Optional, List, Tuple, Union, Iterable, Any, Set
import pdb

import torch
import torch.optim.lr_scheduler
from torch.nn.parallel import replicate, parallel_apply
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from tensorboardX import SummaryWriter

from itertools import tee

from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import peak_memory_mb, gpu_memory_mb
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.models.model import Model
from allennlp.nn import util
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.optimizers import Optimizer

from AttentionSegmentation.commons.trainer_utils import is_sparse,\
    sparse_clip_norm, move_optimizer_to_cuda, TensorboardWriter
# from AttentionSegmentation.visualization.visualize_attns \
#     import html_visualizer
from AttentionSegmentation.model.attn2labels import BasePredictionClass
logger = logging.getLogger(__name__)

TQDM_COLUMNS = 200

import sys
sys.path.insert(0,'/home/ytaille/deep_multilingual_normalization')
from create_classifiers import create_classifiers
from nlstruct.dataloaders import load_from_brat

logger2 = logging.getLogger("nlstruct")
logger2.setLevel(logging.ERROR)

from notebook_utils import *

def _train_epoch(self, epoch: int) -> Dict[str, float]:
    """
    Trains one epoch and returns metrics.

    Notebook replacement for the trainer's own ``_train_epoch``. For every
    batch it:

    1. runs the model and reads the ``'attentions'`` entry of its output,
    2. draws ``n_trajectories`` 0/1 masks from those attention probabilities,
    3. keeps, per sentence / class / mask, only the selected tokens and their
       ``chunk_tags`` labels and writes them out with ``save_to_ann``,
    4. reloads them with ``load_from_brat`` and scores them with
       ``self.classifier1`` via ``preprocess_train`` / ``predict`` /
       ``compute_scores``,
    5. back-propagates ``(scores['loss'] * attentions).mean()``.

    Requirements visible from the code: ``self.vocabularies1`` and
    ``self.classifier1`` must be set on the trainer beforehand (nothing in this
    notebook view sets them), and ``tg`` (``nlstruct.utils.torch_global``) must
    be importable from the notebook namespace.

    NOTE(review): the loss multiplies the scorer's loss by the attention
    probabilities themselves rather than by the log-probability of the sampled
    masks, which is what REINFORCE uses — confirm this is intended.

    :param epoch: index of the current epoch.
    :return: the metrics from ``self._get_metrics(..., reset=True)``.
    """
    # Function-scope imports: only this function needs these two names.
    from allennlp.data.fields.array_field import ArrayField
    from torch.distributions import Binomial

    logger.info(f"Peak CPU memory usage MB: {peak_memory_mb()}")
    if torch.cuda.is_available():
        for gpu, memory in gpu_memory_mb().items():
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

    train_loss = 0.0

    # Tag every instance with its position so that the shuffled batches can be
    # mapped back to the original instances.
    for i, td in enumerate(self._train_data):
        td.fields['sample_id'] = ArrayField(np.array([i]))

    train_generator = self._iterator(self._train_data,
                                     num_epochs=1,
                                     cuda_device=self._iterator_device,
                                     shuffle=True,
                                     )

    # One copy is drained right away to learn the shuffled order, the other is
    # the one actually trained on. (A third, never-consumed copy used to keep
    # every batch buffered for the whole epoch.)
    train_generator, id_generator = tee(train_generator, 2)

    ids = []
    for ig in id_generator:
        ids.extend([int(sid.item()) for sid in ig['sample_id']])

    shuffled_train_data = [self._train_data[i] for i in ids]

    num_training_batches = self._iterator.get_num_batches(self._train_data)
    train_generator_tqdm = Tqdm.tqdm(train_generator,
                                     total=num_training_batches
                                     )
    self._last_log = time.time()

    batches_this_epoch = 0
    if self._batch_num_total is None:
        self._batch_num_total = 0

    # Cursor into `shuffled_train_data`: index of the first instance of the
    # current batch.
    cpt_batch = 0

    # Policy is "attention mask": attention scores should be higher if we want to predict CUI
    # Only take words with attention above threshold when predicting with deep norm -> see if it's enough (reward indicates that)
    # REINFORCE algo: (also known as Monte Carlo PG)
    # - draw N trajectories (N attention paths?) -> discretise attentions to make them 1 / 0? -> see if it works with bernoulli first
    # - evaluate each trajectory then sum (maybe add baseline -> subtract mean of all trajectories rewards)
    # - Expected return is given by sum(prob(Ti | W) * reward(Ti)) -> see again if it works with bernoulli first
    # W are WeakL weights
    # - Gradient ascent of return / gradient descent of negative return
    #
    # Open ideas (not implemented): a horizon (proportion of attention at 1
    # per batch), a trajectory count depending on sentence length, and a
    # discount factor gamma.
    n_trajectories = 10
    ws_inputs_dir = '/home/ytaille/data/tmp/ws_inputs/'
    bert_name = "bert-base-multilingual-uncased"

    # Set the model to "train" mode.
    self._model.train()

    for batch in train_generator_tqdm:

        batches_this_epoch += 1
        self._batch_num_total += 1
        batch_num_total = self._batch_num_total
        # Original notes: batch['labels'] is sentence level, batch['tags'] is word level.
        batch_len = len(batch['labels'])

        # Take this batch's instances and advance the cursor immediately, so
        # that a skipped batch (`continue` below) cannot shift the pairing of
        # every following batch with its instances.
        batch_instances = shuffled_train_data[cpt_batch:cpt_batch + batch_len]
        cpt_batch += batch_len

        if epoch <= -1:
            # NOTE(review): never taken for non-negative epochs; `sum([0])` is
            # a plain int and would fail on `.backward()` below — looks like a
            # disabled warm-up switch.
            trajectory_scores = [0]
        else:
            output_dict = self._model(**batch)
            attns = output_dict['attentions']

            # Binomial with the default total_count of 1: every draw is a 0/1 mask.
            sampler = Binomial(probs=attns)

            real_tokens = [np.array(b.fields['tokens'].tokens) for b in batch_instances]
            gold_norm_labels = [np.array(b.fields['chunk_tags'].labels) for b in batch_instances]

            all_samples = [sampler.sample() for _ in range(n_trajectories)]
            n_classes = all_samples[-1].shape[-1]

            try:
                masked_tokens = _mask_by_attention_samples(real_tokens, all_samples, n_classes)
                masked_gold_norm = _mask_by_attention_samples(gold_norm_labels, all_samples, n_classes)
            except Exception:
                # Masking is known to fail occasionally. Skip the batch rather
                # than carry on with missing values or with the previous
                # batch's masked data.
                logger.warning("Skipping batch %d: could not apply the sampled attention masks",
                               batches_this_epoch, exc_info=True)
                continue

            save_to_ann(masked_tokens, masked_gold_norm, ws_inputs_dir)

            # NLSTRUCT PART

            dataset = load_from_brat(ws_inputs_dir)

            if len(dataset['mentions']) == 0:
                continue

            dataset['mentions']['mention_id'] = dataset['mentions']['doc_id'] +'.'+ dataset['mentions']['mention_id'].astype(str)

            batcher, _, _ = preprocess_train(
                dataset,
                vocabularies=self.vocabularies1,
                bert_name=bert_name,
            )

            tg.set_device('cuda:0')

            pred_batcher = predict(batcher, self.classifier1, batch_size=64)

            scores = compute_scores(pred_batcher, batcher)

            trajectory_scores = [(scores['loss'] * attns).mean()]

        self._optimizer.zero_grad()
        loss = sum(trajectory_scores) # policy_loss self._batch_loss(batch, for_training=True) +
        loss.backward()

        train_loss += loss.item()
        batch_grad_norm = self._rescale_gradients()

        # This does nothing if batch_num_total is None or you are using an
        # LRScheduler which doesn't update per batch.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step_batch(batch_num_total)

        self._optimizer.step()

        # Update the description with the latest metrics
        metrics = self._get_metrics(train_loss, batches_this_epoch)
        description = self._description_from_metrics(metrics)

        train_generator_tqdm.set_description(description, refresh=False)
        if hasattr(self, "_tf_params") and self._tf_params is not None:
            # We have TF logging
            if self._batch_num_total % self._tf_params["log_every"] == 0:
                self._tf_log(metrics, self._batch_num_total)

    return self._get_metrics(train_loss, batches_this_epoch, reset=True)


def _mask_by_attention_samples(sequences, samples, n_classes):
    """Return the selected elements of every sequence for each class and sampled mask.

    The result is a flat list ordered by sequence, then class, then sample.
    Each sample is indexed as ``sample[row, :len(sequence), class]`` and the
    resulting boolean mask selects the kept elements of the sequence.
    """
    def _select(values, sample, row, class_n):
        keep = sample[row, :len(values), class_n].cpu().to(bool)
        # Length-1 sequences need the extra list wrapper ("weird behaviour for
        # len == 1" in the original notes).
        return values[keep] if len(values) > 1 else values[[keep]]

    return [
        _select(values, sample, row, class_n)
        for row, values in enumerate(sequences)
        for class_n in range(n_classes)
        for sample in samples
    ]
    
import functools

# Bind the notebook version of `_train_epoch` to this trainer object: the
# instance attribute takes precedence over the class's method for this object only.
trainer._train_epoch = functools.partial(_train_epoch, trainer)

In [12]:
# Run the training loop; each epoch goes through the `_train_epoch` function
# bound to the trainer in the previous cell.
trainer.train()

2021-03-08 11:08:20,564: INFO: Beginning training.
2021-03-08 11:08:20,567: INFO: Starting Training Epoch 1/50
2021-03-08 11:08:20,569: INFO: Peak CPU memory usage MB: 6870.372
2021-03-08 11:08:20,674: INFO: GPU 0 memory usage MB: 5750


  return_array[slices] = self.array
CompositeMention: 0.0479, DiseaseClass: 0.1958, Modifier: 0.4208, SpecificDisease: 0.5646, accuracy: 0.3073, loss: 0.1034 ||: 100%|██████████| 57/57 [02:16<00:00,  2.39s/it]

2021-03-08 11:10:38,805: INFO: Starting with Validation



CompositeMention: 0.0878, DiseaseClass: 0.2382, Modifier: 0.3981, SpecificDisease: 0.5611, accuracy: 0.3213, loss: 0.7400 ||: 100%|██████████| 10/10 [00:01<00:00,  6.76it/s]

2021-03-08 11:10:40,289: INFO: Validation done. (350.0 / 1276) zero predicted





2021-03-08 11:10:48,580: INFO: Best validation performance so far. Copying weights to './trained_models/NCBI-BERT-realFT-PS/run-180/models/best.th'.
2021-03-08 11:10:55,462: INFO: Metrics:
                        Training DiseaseClass      : 0.241  Validation DiseaseClass      : 0.238
                        Training CompositeMention  : 0.052  Validation CompositeMention  : 0.088
                        Training SpecificDisease   : 0.631  Validation SpecificDisease   : 0.561
                        Training accuracy          : 0.332  Validation accuracy          : 0.321
                        Training Modifier          : 0.404  Validation Modifier          : 0.398
                        Training loss              : 0.027  Validation loss              : 0.740

2021-03-08 11:10:55,463: INFO: Writing validation visualization at ./trained_models/NCBI-BERT-realFT-PS/run-180/visualization/validation.html


100%|██████████| 10/10 [00:02<00:00,  3.58it/s]


2021-03-08 11:10:58,779: INFO: Tag: CompositeMention, Acc: 8.78
2021-03-08 11:10:58,781: INFO: Tag: DiseaseClass, Acc: 23.82
2021-03-08 11:10:58,783: INFO: Tag: Modifier, Acc: 39.81
2021-03-08 11:10:58,784: INFO: Tag: SpecificDisease, Acc: 56.11
2021-03-08 11:10:58,785: INFO: Average ACC: 32.13
2021-03-08 11:10:58,834: INFO: processed 22501 tokens with 119 phrases; 
2021-03-08 11:10:58,835: INFO: found: 756 phrases; correct: 0.

2021-03-08 11:10:58,836: INFO: accuracy:  92.78%; 
2021-03-08 11:10:58,837: INFO: precision:   0.00%; recall:   0.00%; FB1:   0.00
2021-03-08 11:10:58,838: INFO:               TAG  precision   recall      FB1
2021-03-08 11:10:58,839: INFO:  CompositeMention      0.00%    0.00%    0.00%
2021-03-08 11:10:58,840: INFO:      DiseaseClass      0.00%    0.00%    0.00%
2021-03-08 11:10:58,841: INFO:          Modifier      0.00%    0.00%    0.00%
2021-03-08 11:10:58,842: INFO:   SpecificDisease      0.00%    0.00%    0.00%
2021-03-08 11:10:58,845: INFO: Writing predict

  return_array[slices] = self.array
CompositeMention: 0.0802, DiseaseClass: 0.2290, Modifier: 0.3953, SpecificDisease: 0.6027, accuracy: 0.3268, loss: 0.0932 ||:  11%|█         | 6/57 [00:34<04:55,  5.79s/it]

2021-03-08 11:11:34,707: ERROR: Internal Python error in the inspect module.
Below is the traceback from this internal error.

Traceback (most recent call last):
  File "/home/ytaille/.conda/envs/deep_multilingual_normalization/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-12-3435b262f1ae>", line 1, in <module>
    trainer.train()
  File "/home/ytaille/AttentionSegmentation/AttentionSegmentation/trainer.py", line 61, in train
    super(Trainer, self).train(*args, **kwargs)
  File "/home/ytaille/AttentionSegmentation/AttentionSegmentation/commons/trainer.py", line 588, in train
    train_metrics = self._train_epoch(epoch)
  File "<ipython-input-11-6104dadf68b1>", line 226, in _train_epoch
    bert_name=bert_name,
  File "/home/ytaille/AttentionSegmentation/notebook_utils.py", line 513, in preprocess_train
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
  File "/h




TypeError: object of type 'NoneType' has no len()

In [None]:
# Wrap-up: report the end of training and, when a test split was loaded
# (`instances_test` stays None otherwise), pass it to `trainer.test`.
logger.info("Training Done.")
if instances_test is not None:
    logger.info("Computing final Test Accuracy")
    trainer.test(instances_test)
logger.info("Done.")