In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# IMPORTS

In [2]:
import sys
import pickle as pkl

import string
import logging

from nlstruct.utils import torch_clone
from nlstruct.utils import torch_global as tg
from nlstruct.dataloaders import load_from_brat
from nlstruct.collections import Batcher

sys.path.insert(0,'./deep_multilingual_normalization')

from deep_multilingual_normalization.preprocess import preprocess, load_quaero
from deep_multilingual_normalization.train import train_step1, train_step2, clear
from deep_multilingual_normalization.eval import predict, compute_scores

from notebook_utils import *

# PATHS

In [3]:
# NOTE(review): both paths are absolute and machine-specific; adjust them
# before running the notebook elsewhere.
# Directory from which the vocabulary pickles (vocab1.pkl / vocab2.pkl) are read.
VOCAB_PATH = '/home/ytaille/christel_g'
# Brat corpus directory handed to load_from_brat for prediction.
PREDICT_DATA_PATH = '/home/ytaille/data/resources/mantra/Mantra-GSC/English/Medline_EN_FR_ec22-cui-best_man'

# Load classifiers

In [4]:
# Send INFO-level log records to stdout with an empty format string, and only
# let ERROR (or worse) through from the "transformers" logger.
logging.basicConfig(level=logging.INFO, format="", stream=sys.stdout)
logging.getLogger("transformers").setLevel(logging.ERROR)

# (pattern, replacement) pairs passed to preprocess() as `subs`
# (presumably applied in order as regex substitutions - TODO confirm):
#   1. insert a space after a punctuation/backslash character not followed by one
#   2. insert a space before a punctuation/backslash character not preceded by one
#   3. insert a space between a letter and a following digit
#   4. insert a space between a digit and a following letter
#   5. collapse runs of two or more spaces into a single space
subs = [
    (r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
    (r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
    ("(?<=[a-zA-Z])(?=[0-9])", r" "),
    ("(?<=[0-9])(?=[A-Za-z])", r" "),
    ("[ ]{2,}", " ")
]

# Name of the pretrained BERT model, reused by preprocess() and train_step1().
bert_name = "bert-base-multilingual-uncased"
# Select the first GPU through nlstruct's global torch state and keep a handle on it.
tg.set_device('cuda:0')
device = tg.device

# Step 1
# Build the step-1 training data: French ("FRE") is the source language with
# its full lexicon, the Quaero "train" split is added, and English ("ENG") is
# the other language with other_full_lexicon=False and
# other_mirror_existing_labels=True (presumably: English synonyms only for
# labels already present on the French side - TODO confirm in preprocess()).
# The eight returned objects are notebook globals; the step-2 cell rebinds the
# same names.
train_batcher, vocabularies, train_mentions, train_mention_ids, group_label_mask, quaero_batcher, quaero_mentions, quaero_mention_ids = preprocess(
    bert_name=bert_name,
    umls_versions=["2021AA"],
    source_full_lexicon=True,
    source_lat=["FRE"],
    add_quaero_splits=["train"],
    other_full_lexicon=False,
    other_lat=["ENG"],
    other_additional_labels=None,
    other_mirror_existing_labels=True,
    sty_groups=['ANAT', 'CHEM', 'DEVI', 'DISO', 'GEOG', 'LIVB', 'OBJC', 'PHEN', 'PHYS', 'PROC'],
    other_sabs=None,
    subs=subs,
    apply_unidecode=True,
    max_length=100,
)

# Train the step-1 classifier on the step-1 batcher. `classifier` is cloned
# and passed to train_step2() in the next cell; `history` is not read again in
# the visible cells.
history, classifier = train_step1(
    # Data
    train_batcher=train_batcher,
    # NOTE(review): the validation split is selected with the literal index 0,
    # whereas the step-2 cell looks up the index of 'dev' in
    # vocabularies['split'] - confirm that 0 is the intended split here.
    val_batcher=quaero_batcher[quaero_batcher['split'] == 0],
    vocabularies=vocabularies,
    group_label_mask=group_label_mask,

    # Learning rates
    metric_lr=8e-3,
    inter_lr=8e-3,
    bert_lr=2e-5,

    # Misc params
    metric='clustered_cosine',
    dim=350,
    rescale=20,
    bert_name=bert_name,
    batch_norm_affine=True,
    batch_norm_momentum=0.1,
    train_with_groups=True,

    # Regularizers
    dropout=0.2,
    bert_dropout=0.2,
    mask_and_shuffle=(2, 0.5, 0.1),
    n_freeze=0,
    sort_noise=1.,
    n_neighbors=None,

    # Scheduling
    batch_size=128,
    bert_warmup=0.1,
    max_epoch=15,
    decay_schedule="linear",

    # Experiment params
    seed=123456,
    stop_epoch=None,
    with_cache=True,
    debug=False,
    from_tf=False,
    with_tqdm=True,
)


# with open('/home/ytaille/AttentionSegmentation/vocab1.pkl', 'wb') as f:
#     pkl.dump(vocabularies, f)

Available CUDA devices: 1
Current device: cuda:0
Using cache /home/ytaille/christel_g/cache/preprocess_training_data/84f2788785216393
Loading /home/ytaille/christel_g/cache/preprocess_training_data/84f2788785216393/output.pkl... 
Loading MRCONSO...
Deduplicating MRCONSO...
French synonyms: 389221
French labels: 152156
Quaero mentions: 5714
Mirrored labels: 152828
Queried english labels: 152828
Total deduplicated synonyms: 1430324
Total deduplicated labels: 152828
Will train vocabulary for label
Will train vocabulary for group
Will train vocabulary for source
Will train vocabulary for token
Discovered existing vocabulary (105879 entities) for token
Normalized split, with given vocabulary and no unk
Normalized split, with given vocabulary and no unk
2828
Total deduplicated synonyms: 1430324
Total deduplicated labels: 152828
Will train vocabulary for label
Will train vocabulary for group
Will train vocabulary for source
Will train vocabulary for token
Discovered existing vocabulary (10587

In [5]:
# Step 2
# Rebuild the training data with a wider English side: compared with the
# step-1 call, other_full_lexicon is True and other_sabs lists five source
# vocabularies (CHV, SNOMEDCT_US, MTH, NCI, MSH). All other arguments are
# identical to step 1. This rebinds the step-1 globals (train_batcher,
# vocabularies, ...), so this cell must run after step-1 training.
train_batcher, vocabularies, train_mentions, train_mention_ids, group_label_mask, quaero_batcher, quaero_mentions, quaero_mention_ids = preprocess(
    bert_name=bert_name,
    umls_versions=["2021AA"],
    source_full_lexicon=True,
    source_lat=["FRE"],
    add_quaero_splits=["train"],
    other_full_lexicon=True,
    other_lat=["ENG"],
    other_additional_labels=None,
    other_mirror_existing_labels=True,
    sty_groups=['ANAT', 'CHEM', 'DEVI', 'DISO', 'GEOG', 'LIVB', 'OBJC', 'PHEN', 'PHYS', 'PROC'],
    other_sabs=["CHV", "SNOMEDCT_US", "MTH", "NCI", "MSH"],
    subs=subs,
    apply_unidecode=True,
    max_length=100,
)

# UNCOMMENT HERE TO SWITCH TO CPU

# out_features is a parameter that is not updated

# classifier.cpu()
# Run step-2 training on a clone of the step-1 classifier (moved to the
# current device), so `classifier` itself is left as produced by step 1.
# Validation uses the Quaero rows whose split index is that of 'dev' in
# vocabularies['split'].
classifier2 = train_step2(
    classifier=torch_clone(classifier).to(tg.device),
    train_batcher=train_batcher,
    val_batcher=quaero_batcher[quaero_batcher['split'] == list(vocabularies['split']).index('dev')],
    group_label_mask=group_label_mask,
    batch_size=128,
    sort_noise=1.,
    decay_schedule="linear",
    lr=8e-3,
    n_epochs=5,
    seed=123456,
    rescale=20,
    n_neighbors=100,
)

# with open('/home/ytaille/AttentionSegmentation/vocab2.pkl', 'wb') as f:
#     pkl.dump(vocabularies, f)

Using cache /home/ytaille/christel_g/cache/preprocess_training_data/4d2b03f45f152b1f
Loading /home/ytaille/christel_g/cache/preprocess_training_data/4d2b03f45f152b1f/output.pkl... 
Loading MRCONSO...
Deduplicating MRCONSO...
French synonyms: 389221
French labels: 152156
Quaero mentions: 5714
Mirrored labels: 152828
Queried english labels: 152828
Adding all english concepts from SABs: ['CHV', 'SNOMEDCT_US', 'MTH', 'NCI', 'MSH']
Total deduplicated synonyms: 3641321
Total deduplicated labels: 1054783
Will train vocabulary for label
Will train vocabulary for group
Will train vocabulary for source
Will train vocabulary for token
Discovered existing vocabulary (105879 entities) for token
Normalized split, with given vocabulary and no unk
Normalized split, with given vocabulary and no unk
Quaero mentions: 16283
Normalized split, with given vocabulary and no unk
Normalized label, with given vocabulary and no unk
Normalized quaero_source, with given vocabulary and no unk
Normalized token, with 

In [6]:
# Same model name as the one set in the setup cell; repeated here so this cell
# can be run without it.
bert_name = "bert-base-multilingual-uncased"

# NOTE(review): this import sits mid-notebook; export_to_brat further down
# relies on `os` being in scope as well.
import os

# Load the two pickled vocabulary sets from VOCAB_PATH.
# NOTE(review): pickle.load can execute arbitrary code - only load files you
# produced yourself. The commented-out pkl.dump calls in the training cells
# write to a different directory (/home/ytaille/AttentionSegmentation), so
# confirm these files match the vocabularies the loaded model was trained with.
with open(os.path.join(VOCAB_PATH,'vocab1.pkl'), 'rb') as f:
    vocabularies1 = pkl.load(f)

# Only vocabularies2 is used by the visible cells (passed to preprocess_train).
with open(os.path.join(VOCAB_PATH,'vocab2.pkl'), 'rb') as f:
    vocabularies2 = pkl.load(f)

In [7]:
# Load the brat corpus to predict on.
dataset = load_from_brat(PREDICT_DATA_PATH)

# Prefix every mention_id with its doc_id ("<doc_id>.<mention_id>") in both the
# mentions and fragments tables (presumably to make ids unique across
# documents - TODO confirm). The tables are modified in place; the cell stays
# re-runnable only because `dataset` is reloaded just above.
dataset['mentions']['mention_id'] = dataset['mentions']['doc_id'] +'.'+ dataset['mentions']['mention_id'].astype(str)
dataset['fragments']['mention_id'] = dataset['fragments']['doc_id'] +'.'+ dataset['fragments']['mention_id'].astype(str)

# NOTE(review): preprocess_train is not imported by name in this notebook; it
# presumably comes from `from notebook_utils import *` - confirm.
batcher, vocs, mention_ids = preprocess_train(
    dataset,
    vocabularies=vocabularies2,
    bert_name=bert_name,
)

# NOTE(review): batch_size and with_tqdm are not read by any visible cell
# (predict below is called with a literal batch_size=64).
batch_size = len(batcher)
with_tqdm = True

tg.set_device('cuda:0') #('cuda:0')
device = tg.device

# Predict with the step-2 classifier. The next cell calls predict again with
# return_loss=True and overwrites pred_batcher.
pred_batcher = predict(batcher, classifier2, batch_size=64, return_loss=False)

Will train vocabulary for label
Normalized token, with given vocabulary and no unk
Available CUDA devices: 1
Current device: cuda:0


In [8]:
# Re-run prediction with return_loss=True, then score the predictions against
# the gold batcher; the scores are the cell's displayed output.
pred_batcher = predict(batcher, classifier2, batch_size=64, return_loss=True)
compute_scores(pred_batcher, batcher)

{'recall': 0.704225352112676,
 'precision': 0.704225352112676,
 'f1': 0.704225352112676,
 'pred_count': 284.0,
 'gold_count': 284.0,
 'tp': 200.0,
 'loss': 1.2823140587605222,
 'map': 0.8039885340995149}

In [None]:
# MANTRA

# FR:
# MEDLINE: 0.634
# EMEA: 0.610

# EN:
# MEDLINE: 0.704
# EMEA: 0.651

# QUAERO:

# MEDLINE: 0.708
# EMEA: -> Duplicated mention_id; it may not be worth fixing this problem...

In [None]:
# Short aliases for the token and label vocabularies returned by
# preprocess_train; they are indexed by integer id in the decoding code below.
voc_tokens = vocs['token']
voc_labels = vocs['label']

def replace_fn(s):
    """Turn a space-joined WordPiece token string back into plain text.

    Removes every ' ##' continuation marker (gluing sub-word pieces to the
    preceding piece), then deletes the '[SEP]', '[PAD]' and '[CLS]' special
    tokens in that order, and finally strips leading/trailing whitespace.
    Spaces left between the remaining words are not collapsed.
    """
    cleaned = s.replace(' ##', '')
    for special_token in ('[SEP]', '[PAD]', '[CLS]'):
        cleaned = cleaned.replace(special_token, '')
    return cleaned.strip()

# Merge each mention's id and token ids with the predictions.
merged_batcher = batcher['mention', ['mention_id', 'token']].merge(pred_batcher)

# Map every row of token ids back to its tokens, join them with spaces and
# clean the result with replace_fn.
token_id_rows = merged_batcher['mention']['token'].toarray()
final_tokens = [
    replace_fn(' '.join(voc_tokens[token_id] for token_id in token_ids))
    for token_ids in token_id_rows
]
# Map the label ids of each mention back to their label strings.
final_labels = [voc_labels[label_id] for label_id in merged_batcher['mention']['label']]

# (decoded mention text, decoded label) pairs, in batcher order.
final_couples = list(zip(final_tokens, final_labels))



In [61]:
def _brat_spans(mention_text, begin):
    """Return the (begin, end) character spans of a mention, split at newlines.

    The mention text is cut on "\n"; each non-empty piece yields one span and
    the newline itself (one character) is skipped between pieces.
    """
    spans = []
    cursor = begin
    for part in mention_text.split("\n"):
        end = cursor + len(part)
        if end != cursor:
            spans.append((cursor, end))
        cursor = end + 1
    return spans


def export_to_brat(dataset, norm_labels, dest=None, filename_prefix=""):
    """Export documents and their labelled mentions as brat ``.txt``/``.ann`` files.

    Parameters
    ----------
    dataset : mapping of tables
        Must provide ``dataset["docs"]`` with ``doc_id`` and ``text`` columns
        and ``dataset["mentions"]`` with a ``doc_id`` column. If the mentions
        table has no ``begin`` column it is merged with
        ``dataset["fragments"]`` to obtain the ``begin``/``end`` offsets.
    norm_labels : sequence
        One entry per exported mention row, consumed in export order (documents
        in ``docs`` order, mentions in table order). Element ``[0]`` of each
        entry is written as the annotation label.
        NOTE(review): presumably each entry is a ranked list of predictions
        and ``[0]`` is the top one; the alignment with the mention order is
        not checked here - confirm at the call site.
    dest : str or None
        Output directory, created if missing. When ``None`` nothing is written
        to disk and the annotation lines are printed to stdout instead.
    filename_prefix : str
        Prepended to every ``doc_id`` to build the output file names.

    Returns
    -------
    None

    Raises
    ------
    IndexError
        If ``norm_labels`` holds fewer entries than there are mention rows.
    """
    # dict() keeps one text per doc_id (the last one wins), as before.
    doc_id_to_text = dict(zip(dataset["docs"]["doc_id"], dataset["docs"]["text"]))
    mentions = dataset["mentions"]

    if "begin" not in mentions:
        mentions = mentions.merge(dataset["fragments"])
    if dest is not None:
        os.makedirs(dest, exist_ok=True)

    # Running index into norm_labels, shared across all documents.
    mention_counter = 0

    for doc_id, text in doc_id_to_text.items():
        base_path = None
        if dest is not None:
            base_path = os.path.join(dest, filename_prefix + doc_id)
            txt_path = base_path + ".txt"
            # Only write the text when it is not already there. The original
            # check tested the unformatted template "{}/{}.txt" and so always
            # rewrote the file (and crashed when dest was None).
            if not os.path.exists(txt_path):
                with open(txt_path, "w", encoding="utf-8") as txt_file:
                    txt_file.write(text)

        # Exact match: a substring/regex match would also pull in the mentions
        # of any document whose id merely contains this one ("doc1" vs "doc10").
        doc_mentions = mentions[mentions["doc_id"] == doc_id]

        ann_lines = []
        for mention_i, (_, row) in enumerate(doc_mentions.iterrows(), start=1):
            begin, end = int(row["begin"]), int(row["end"])
            mention_text = text[begin:end]
            spans = _brat_spans(mention_text, begin)
            ann_lines.append("T{}\t{} {}\t{}".format(
                mention_i,
                norm_labels[mention_counter][0],
                ";".join(" ".join(map(str, span)) for span in spans),
                mention_text.replace("\n", " ")))
            mention_counter += 1

        if base_path is None:
            for line in ann_lines:
                print(line)
        else:
            # An (empty) .ann file is still created for documents without mentions.
            with open(base_path + ".ann", "w", encoding="utf-8") as ann_file:
                ann_file.writelines(line + "\n" for line in ann_lines)


In [62]:
# Write the corpus with the predicted labels as brat files.
# NOTE(review): the destination is an absolute, machine-specific path.
export_to_brat(dataset, final_labels, dest="/home/ytaille/pyner/preds/medline_norm_preds")

In [None]:
# load quaero with CUI as labels instead of NER