In [1]:
import nlstruct

In [2]:
from nlstruct.dataloaders.medic import *
from nlstruct.dataloaders.ncbi_disease import *

In [3]:
from nlstruct.text import huggingface_tokenize, regex_sentencize, partition_spans, encode_as_tag, split_into_spans, apply_substitutions, apply_deltas
from nlstruct.dataloaders import load_from_brat, load_genia_ner
from nlstruct.collections import Dataset, Batcher
from nlstruct.utils import merge_with_spans, normalize_vocabularies, factorize_rows, df_to_csr, factorize, torch_global as tg
# from nlstruct.modules.crf import BIODecoder, BIOULDecoder
from nlstruct.environment import root, cached
from nlstruct.train import seed_all
from itertools import chain, repeat

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
import math

import pandas as pd
import numpy as np
import re
import string
from transformers import AutoModel, AutoTokenizer

from nlstruct.utils import encode_ids

In [4]:
# # bit of code for conll to ann

# full_text = ' '.join(a[0]['text'])

# cpt_car = 0

# mention_id = 0
# mention_begin = 0
# mention_end = 0
# mention_label = ''
# mention_text = ''

# # T1	sosy 95 107	vomissements
# # DUMP vocs for prediction
# ann_str = ''

# for t, g in zip(a[0]['text'], a[0]['gold_labels']):
    
    
#     if mention_text != '' and g[0] in ['B', 'O']:
#         mention_end = cpt_car - 1

#         ann_str += f'T{mention_id}\t{mention_label} {mention_begin} {mention_end}\t{mention_text.strip(" ")}\n'

#         mention_text = ''
#         mention_id += 1
        
#     if g[0] == 'B':
#         mention_begin = cpt_car
#         mention_text = t + ' '
#         mention_label = g[2:]
        
#     elif g[0] == 'I':
#         mention_text += t + ' '
            
#     cpt_car += len(t) + 1

In [5]:
#@cached
def preprocess_train(
    dataset,
    max_sentence_length,
    bert_name,
    ner_labels=None,
    unknown_labels="drop",
    vocabularies=None,
):
    """
    Parameters
    ----------
        dataset: Dataset
        max_sentence_length: int
            Max number of "words" as defined by the regex in regex_sentencize (so this is not the nb of wordpieces)
        bert_name: str
            bert path/name
        ner_labels: list of str 
            allowed ner labels (to be dropped or filtered)
        unknown_labels: str
            "drop" or "raise"
        vocabularies: dict[str; np.ndarray or list]
    Returns
    -------
    (pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, dict[str; np.ndarray or list])
        docs:      ('split', 'text', 'doc_id')
        sentences: ('split', 'doc_id', 'sentence_idx', 'begin', 'end', 'text', 'sentence_id')
        mentions:  ('ner_label', 'doc_id', 'sentence_id', 'mention_id', 'depth', 'text', 'mention_idx', 'begin', 'end')
        tokens:    ('split', 'token', 'sentence_id', 'token_id', 'token_idx', 'begin', 'end', 'doc_id', 'sentence_idx')
        deltas:    ('doc_id', 'begin', 'end', 'delta')
        vocs: vocabularies to be reused later for encoding more data or decoding predictions
    """
    print("Dataset:", dataset)
    mentions = dataset["mentions"].rename({"label": "ner_label"}, axis=1)
    
    if ner_labels is not None:
        len_before = len(mentions)
        unknown_ner_labels = list(mentions[~mentions["ner_label"].isin(ner_labels)]["ner_label"].drop_duplicates())
        
        mentions = mentions[mentions["ner_label"].isin(ner_labels)]
        if len(unknown_ner_labels) and unknown_labels == "raise":
            raise Exception(f"Unkown labels in {len_before-len(mentions)} mentions: ", unknown_ner_labels)

    # Check that there is no mention overlap
#     mentions = mentions.merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))
    
    print("Transform texts...", end=" ")
    docs, deltas = apply_substitutions(
        dataset["docs"], *zip(
            #(r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
            #(r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
            ("(?<=[a-zA-Z])(?=[0-9])", r" "),
            ("(?<=[0-9])(?=[A-Za-z])", r" "),
        ), apply_unidecode=True)
    docs = docs.astype({"text": str})
    transformed_mentions = apply_deltas(mentions, deltas, on=['doc_id'])
    print("done")
    
    print("Splitting into sentences...", end=" ")
    sentences = regex_sentencize(
        docs, 
        reg_split=r"(?<=[.])(\s*\n+)|(?=, [0-9]\))",
        min_sentence_length=None, max_sentence_length=max_sentence_length,
        balance_parentheses=False,
        # balance_parentheses=True, # default is True
    )
          
    [mentions], sentences, sentence_to_docs = partition_spans([transformed_mentions], sentences, new_id_name="sentence_id", overlap_policy=False)
    

    mentions = mentions[(mentions['begin']>0) & (mentions['end']>0)]
    
    n_sentences_per_mention = mentions.assign(count=1).groupby(["doc_id", "mention_id"], as_index=False).agg({"count": "sum", "text": tuple, "sentence_id": tuple})
    if n_sentences_per_mention["count"].max() > 1:
        display(n_sentences_per_mention.query("count > 1"))
        display(sentences[sentences["sentence_id"].isin(n_sentences_per_mention.query("count > 1")["sentence_id"].explode())]["text"].tolist())
        raise Exception("Some mentions could be mapped to more than 1 sentences ({})".format(n_sentences_per_mention["count"].max()))
    if sentence_to_docs is not None:
        mentions = mentions.merge(sentence_to_docs)
        
    mentions = mentions.assign(mention_idx=0).nlstruct.groupby_assign(["doc_id", "sentence_id"], {"mention_idx": lambda x: tuple(range(len(x)))})
    print("done")
    
    print("Tokenizing...", end=" ")
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
    sentences["text"] = sentences["text"].str.lower()
    tokens = huggingface_tokenize(sentences, tokenizer, doc_id_col="sentence_id")

    prep = Dataset(docs=docs, sentences=sentences, mentions=mentions, tokens=tokens).copy()    
    
    mentions = split_into_spans(mentions, tokens, pos_col="token_idx", overlap_policy=False)
    print("done")
    
    print("Processing nestings (overlapping areas)...", end=" ")
    # Extract overlapping spans
    conflicts = (
        merge_with_spans(mentions, mentions, on=["doc_id", "sentence_id", ("begin", "end")], how="outer", suffixes=("", "_other"))
    )
    # ids1, and ids2 make the edges of the overlapping mentions of the same type (see the "ner_label")
    [ids1, ids2], unique_ids = factorize_rows(
        [conflicts[["doc_id", "sentence_id", "mention_id"]], 
         conflicts[["doc_id", "sentence_id", "mention_id_other"]]],
        mentions.eval("size=(end-begin)").sort_values("size")[["doc_id", "sentence_id", "mention_id"]]
    )
    g = nx.from_scipy_sparse_matrix(df_to_csr(ids1, ids2, n_rows=len(unique_ids), n_cols=len(unique_ids)))
    colored_nodes = np.asarray(list(nx.coloring.greedy_color(g, strategy=keep_order).items()))
    unique_ids['depth'] = colored_nodes[:, 1][colored_nodes[:, 0].argsort()]
    mentions = mentions.merge(unique_ids)
    print("done")
    
    print("Computing vocabularies...")
    [docs, sentences, mentions, tokens], vocs = normalize_vocabularies(
        [docs, sentences, mentions, tokens], 
        vocabularies={"split": ["train", "val", "test"]} if vocabularies is None else vocabularies,
        train_vocabularies={"source": False, "text": False} if vocabularies is None else False,
        verbose=True)
    print("done")
    
    unique_mention_ids = encode_ids([mentions], ("doc_id", "sentence_id", "mention_id"))
    unique_sentence_ids = encode_ids([sentences, mentions, tokens], ("doc_id", "sentence_id"))
    unique_doc_ids = encode_ids([docs, sentences, mentions, tokens], ("doc_id",))
    
    batcher = Batcher({
        "mention": {
            "mention_id": mentions["mention_id"],
            "sentence_id": mentions["sentence_id"],
            "doc_id": mentions["doc_id"],
            "begin": mentions["begin"],
            "end": mentions["end"],
            "depth": mentions["depth"],
            "ner_label": mentions["ner_label"].cat.codes,
        },
        "sentence": {
            "sentence_id": sentences["sentence_id"],
            "doc_id": sentences["doc_id"],
            "mention_id": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], mentions["mention_id"], n_rows=len(unique_sentence_ids)),
            "mention_mask": df_to_csr(mentions["sentence_id"], mentions["mention_idx"], n_rows=len(unique_sentence_ids)),
            "token": df_to_csr(tokens["sentence_id"], tokens["token_idx"], tokens["token"].cat.codes, n_rows=len(unique_sentence_ids)),
            "token_mask": df_to_csr(tokens["sentence_id"], tokens["token_idx"], n_rows=len(unique_sentence_ids)),
        },
        "doc": {
            "doc_id": np.arange(len(unique_doc_ids)),
            "sentence_id": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], sentences["sentence_id"], n_rows=len(unique_doc_ids)),
            "sentence_mask": df_to_csr(sentences["doc_id"], sentences["sentence_idx"], n_rows=len(unique_doc_ids)),
            "split": docs["split"].cat.codes,
        }},
        masks={"sentence": {"token": "token_mask", "mention_id": "mention_mask"}, 
               "doc": {"sentence_id": "sentence_mask"}}
    )
    
    return batcher, prep, deltas, vocs

def keep_order(G, colors):
    """Returns a list of the nodes of ``G`` in ordered identically to their id in the graph
    ``G`` is a NetworkX graph. ``colors`` is ignored.
    This is to assign a depth using the nx.coloring.greedy_color function
    """
    return sorted(list(G))

In [6]:
bert_name = "bert-base-cased"
dataset = load_ncbi_disease()

print(dataset['mentions'])

docs = dataset['docs']

keep_n_first = None

if keep_n_first:
    docs = docs[:keep_n_first]
    first_ids = docs['doc_id']
    
    first_mentions = dataset["mentions"].loc[dataset["mentions"]['doc_id'].isin(first_ids)]
    first_fragments = dataset["fragments"].loc[dataset["fragments"]['doc_id'].isin(first_ids)]
    first_attributes = dataset["attributes"].loc[dataset["attributes"]['doc_id'].isin(first_ids)]
    
    dataset["mentions"] = first_mentions
    dataset["fragments"] = first_fragments
    dataset["attributes"] = first_attributes

dataset['mentions'] = dataset['mentions'].merge(dataset["fragments"].groupby(["doc_id", "mention_id"], as_index=False, observed=True).agg({"begin": "min", "end": "max"}))

merged = dataset['mentions'].merge(dataset['labels'], on=['doc_id', 'mention_id'])
merged['mention_id'] = merged['label_id']
merged['norm_label'] = merged['label'] # rename norm labels
merged['label'] = merged['category']
    
dataset['docs'] = docs

dataset['mentions'] = merged

ner_labels = list(dataset['mentions']['label'].unique())

batcher, prep, deltas, vocs = preprocess_train(
    dataset=dataset,
    max_sentence_length=140,
    bert_name=bert_name,
    ner_labels= ner_labels,
    unknown_labels="drop",
)

        doc_id  mention_id         category  \
0     10192393  10192393-0     DiseaseClass   
1     10192393  10192393-1     DiseaseClass   
2     10192393  10192393-2     DiseaseClass   
3     10192393  10192393-3  SpecificDisease   
4     10192393  10192393-4  SpecificDisease   
...        ...         ...              ...   
6876   8696339   8696339-0  SpecificDisease   
6877   8696339   8696339-1  SpecificDisease   
6878   8696339   8696339-2         Modifier   
6879   8696339   8696339-3  SpecificDisease   
6880   8696339   8696339-4     DiseaseClass   

                                     text  
0                             skin tumour  
1                                  cancer  
2                           colon cancers  
3              adenomatous polyposis coli  
4                                     APC  
...                                   ...  
6876                   Huntington disease  
6877                                   HD  
6878                                   



done
Processing nestings (overlapping areas)... done
Computing vocabularies...
done


In [24]:
from os import path
from pathlib import Path 

split = 'all'

all_sentences = prep['sentences']#[prep['sentences']['split']==split]
all_mentions = prep['mentions']

def dataset_to_ann(
    sentences, 
    mentions,
    ann_path,
):
    
    Path(ann_path).mkdir(parents=True, exist_ok=True)
    
    for doc_id in all_sentences['doc_id'].unique():
        
        doc_char_count = 0
        
        doc_sentences = all_sentences[all_sentences.doc_id==doc_id]
        
        doc_str = ' '.join([d['text'] for _,d in doc_sentences.iterrows()])
        ann_str = ''
        mention_id = 0
        
        for _, sentence in doc_sentences.iterrows():
            sentence_id = sentence['sentence_id']
            first_mentions = all_mentions[all_mentions['sentence_id']==sentence_id]
            first_mentions['begin'] += doc_char_count 
            first_mentions['end'] += doc_char_count 
            
            for _, fm in first_mentions.iterrows():
                ann_str += f"T{mention_id}\t{fm['norm_label']} {fm['begin']} {fm['end']}\t{fm['text']}\n"
                mention_id += 1
                
            doc_char_count += len(sentence['text']) + 1

        with open(path.join(ann_path, f'{doc_id}.ann'), 'w') as f:
            f.write(ann_str)
                
        with open(path.join(ann_path, f'{doc_id}.txt'), 'w') as f:
            f.write(doc_str)
        
dataset_to_ann(all_sentences, all_mentions,
    f"/home/ytaille/data/resources/medic/standoff_cui/{split}/")
        

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [13]:
def dataset_to_conll():

    split = 'dev'

    text_file = f'ncbi_conll_ner_{split}.conll'
    all_sentences = prep['sentences'][prep['sentences']['split']==split]
    all_mentions = prep['mentions']#[prep['mentions']['split']==split]

    all_text = []

    for _, sentence in all_sentences.iterrows():
        sentence_id = sentence['sentence_id']
        first_mentions = all_mentions[all_mentions['sentence_id']==sentence_id]

        if sentence['text'][-2::] != ' .':
            sentence['text'] = sentence['text'][:-1] + ' .'

        sentence['text'] = sentence['text'].replace('\n', ' ')

        i = 0
        text = []
        mentions = []
        norm_mentions = []

        already_seen = []

        if len(first_mentions) == 0:
            bef_text = sentence['text'].split(' ')
            text.extend(bef_text)
            mentions.extend(['O'] * len(bef_text))
            norm_mentions.extend(['O'] * len(bef_text))
        else:
            for _, fm in first_mentions.iterrows():
                begin, end, mention_text, label, norm_label = fm[['begin', 'end', 'text', 'ner_label', 'norm_label']]

                if any([(begin in a) or (end in a) for a in already_seen]):
                    continue

                already_seen.append(range(begin, end))

                bef_text = sentence['text'][i:begin].split(' ')
                bef_mentions = ['O'] * len(bef_text)

                text.extend(bef_text)
                mentions.extend(bef_mentions)
                norm_mentions.extend(bef_mentions)

                mention_text = mention_text.split(' ')
                text.extend(mention_text)
                mentions.extend(['B-' + label] + ['I-' + label] * (len(mention_text)-1))
                norm_mentions.extend(['B-' + norm_label] + ['I-' + norm_label] * (len(mention_text)-1))
                i = end

        after_text = sentence['text'][end:].split(' ')
        text.extend(after_text)
        mentions.extend(['O'] * len(after_text))
        norm_mentions.extend(['O'] * len(after_text))

        all_text.extend([f'{t} NN {n} {m}\n' for t,m,n in zip(text, mentions, norm_mentions) if (t!='' and t!='\n')] + ['\n'])

    with open(text_file, 'w') as f:
        f.write('-DOCSTART- -X- -X- O\n\n')
        f.write(''.join(all_text))