In [1]:
import enum
import glob
import os
from hashlib import new
from pathlib import Path
import time

import functools

import numpy as np
import pandas as pd
import scipy
from flyingsquid.label_model import LabelModel as LMsquid
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from snorkel.labeling.model import LabelModel as LMsnorkel
from snorkel.labeling.model import MajorityLabelVoter

In [2]:
from sklearn.exceptions import UndefinedMetricWarning

import warnings


def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``; accepts anything, returns None."""
    return None


# Install the no-op globally: every later ``warnings.warn(...)`` call that
# looks the function up through the ``warnings`` module is silenced.
warnings.warn = warn

In [3]:
from typing import Tuple, Dict


class Document(object):
    """A named, ordered collection of sentences.

    On construction every sentence gets a back-reference to the document
    (``sentence.document``), an empty annotation dict is created per
    sentence index, and ``props`` starts as an empty dict.
    """

    def __init__(self, name, sentences):
        self.name = name
        self.sentences = sentences
        # back-link each sentence to its owning document
        for sent in sentences:
            sent.document = self
        self.annotations = {idx: {} for idx in range(len(sentences))}
        self.props = {}
        self._text = None

    @property
    def text(self):
        """Document text rebuilt from the sentences.

        Sentences are placed at their absolute start offset; any gap before
        a sentence is filled with spaces. The result is cached in ``_text``
        (an empty result is falsy and is therefore rebuilt on each access).
        """
        if self._text:
            return self._text

        buf = ""
        for sent in self.sentences:
            start = sent.abs_char_offsets[0]
            if len(buf) != start:
                buf += ' ' * (start - len(buf))
            buf += sent.text
        self._text = buf
        return self._text

    def __repr__(self):
        return "Document({})".format(self.name)
        
        
class Sentence(object):
    """A tokenized sentence whose attributes come from keyword arguments.

    The properties below read ``words``, ``abs_char_offsets`` and ``i``, so
    those must be supplied by the caller for the properties to work.
    """

    def __init__(self, **kwargs):
        self.document = None
        # every keyword becomes an instance attribute (may override `document`)
        self.__dict__.update(kwargs)
        self._text = None

    @property
    def text(self):
        """Sentence text rebuilt from ``words``, space-padded to match offsets.

        Cached in ``_text``; an empty result is falsy and is rebuilt each time.
        """
        if not self._text:
            base = self.abs_char_offsets[0]
            buf = ""
            for idx, word in enumerate(self.words):
                # distance between where this word should start and where we are
                gap = (self.abs_char_offsets[idx] - base) - len(buf)
                if gap != 0:
                    buf += ' ' * gap
                buf += word
            self._text = buf
        return self._text

    @property
    def position(self):
        """Value of the ``i`` attribute supplied at construction."""
        return self.i

    @property
    def char_offsets(self):
        """Word offsets relative to the first word of the sentence."""
        base = self.abs_char_offsets[0]
        return [absolute - base for absolute in self.abs_char_offsets]

    def __repr__(self):
        limit = 25
        flat = self.text.strip().replace("\n", " ")
        if len(flat) >= limit:
            flat = flat[0:limit] + '...'
        return "Sentence({})".format(flat)
        

class Span(object):
    """A character range inside a sentence.

    ``char_start`` / ``char_end`` are sentence-relative; ``text`` slices up
    to and including ``char_end``.
    """

    def __init__(self, char_start, char_end, sentence, attrib='words'):
        self.sentence = sentence
        self.char_start = char_start
        self.char_end = char_end
        self.attrib = attrib
        self.props = {}
        self.normalized = None

    @property
    def abs_char_start(self):
        """Span start shifted by the sentence's first absolute offset."""
        return self.sentence.abs_char_offsets[0] + self.char_start

    @property
    def abs_char_end(self):
        """Absolute start plus the (end - start) width of the span."""
        width = self.char_end - self.char_start
        return self.abs_char_start + width

    @property
    def text(self):
        """Sentence text from ``char_start`` through ``char_end`` inclusive."""
        return self.sentence.text[self.char_start:self.char_end + 1]

    def get_word_start(self):
        return self.char_to_word_index(self.char_start)

    def get_word_end(self):
        return self.char_to_word_index(self.char_end)

    def get_n(self):
        """Number of words covered by the span."""
        return 1 + self.get_word_end() - self.get_word_start()

    def char_to_word_index(self, ci):
        """Index of the word containing sentence-relative char offset ``ci``.

        Returns the last word index when ``ci`` is past every offset, and
        None when the sentence has no offsets.
        """
        last = None
        for last, offset in enumerate(self.sentence.char_offsets):
            if ci > offset:
                continue
            # exact hit -> this word; overshoot -> the previous word
            return last if ci == offset else last - 1
        return last

    def word_to_char_index(self, wi):
        return self.sentence.char_offsets[wi]

    def get_attrib_tokens(self, a):
        """Slice of sentence attribute ``a`` covering the span's words."""
        first, last = self.get_word_start(), self.get_word_end()
        return self.sentence.__getattribute__(a)[first:last + 1]

    def __repr__(self):
        return "Span({})".format(self.text.replace("\n", " "))

    def get_attrib_span(self, a, sep=" "):
        """Raw text for 'words'; otherwise attribute tokens joined by ``sep``."""
        if a != 'words':
            return sep.join(self.get_attrib_tokens(a))
        return self.sentence.text[self.char_start:self.char_end + 1]

    def get_span(self, sep=" "):
        return self.get_attrib_span('words', sep)

    def __contains__(self, other_span):
        starts_inside = other_span.abs_char_start >= self.abs_char_start
        ends_inside = other_span.abs_char_end <= self.abs_char_end
        return starts_inside and ends_inside
    
    
class Candidate:
    """A collection of spans, stored unchanged on ``self.spans``."""

    def __init__(self, spans):
        self.spans = spans


class Relation(object):
    """A typed relation over named argument spans.

    ``args`` maps argument name -> span. Every argument is also exposed as
    an instance attribute of the same name (``relation.<arg_name>``).
    """

    def __init__(self,
                 type_name: str,
                 args: "Dict[str, Span]") -> None:
        """
        :param type_name: label of the relation type
        :param args: mapping from argument name to its span
        """
        self.type_name = type_name
        self.args = args
        # expose each argument as an attribute
        self.__dict__.update(args)

    def __iter__(self):
        """Iterate over the argument spans in ``args`` order."""
        for span in self.args.values():
            yield span

    def __getitem__(self, item):
        """Positional access into the argument spans."""
        return list(self.args.values())[item]

    def __repr__(self):
        strs = [span.__repr__() for span in self.args.values()]
        return f"Relation[{self.type_name}]({','.join(strs)})"

    def __eq__(self, other):
        """Equal when both relations bind the same names to spans with equal hashes."""
        if not isinstance(other, Relation):
            # let Python fall back to its default comparison instead of
            # raising AttributeError on objects without `.args`
            return NotImplemented
        hashes = {name: span.__hash__() for name, span in self.args.items()}
        other_hashes = {name: span.__hash__() for name, span in other.args.items()}
        return hashes == other_hashes

    def __hash__(self):
        return hash(sum([s.__hash__() for s in self.args.values()]))

    @property
    def sentence(self):
        """Sentence of the first argument span.

        We assume spans all live in the same sentence. The first span is
        taken from ``self.args`` directly; the previous lookup through
        ``self.arg_names`` failed because that attribute is never set.
        """
        return next(iter(self.args.values())).sentence



class Annotation(object):
    """An entity annotation defined by absolute character spans in a document."""

    def __init__(self, doc_name: str,
                 span: Tuple[Tuple[int,int], ...],
                 etype: str,
                 text: str = None,
                 cid: str = None) -> None:
        """
        :param doc_name: name of the document the annotation belongs to
        :param span: one or more (start, end) offset pairs; more than one
            pair denotes a non-contiguous mention
        :param etype: entity type label
        :param text: optional surface text of the mention
        :param cid: optional identifier, stored unchanged
        """
        # start/end are both taken from the FIRST span segment
        self.abs_char_start = span[0][0]
        self.abs_char_end = span[0][-1]

        self.doc_name = doc_name
        # normalize to a tuple of tuples so the span is hashable
        self.span = tuple([tuple(s) for s in span])
        self.text = text
        self.etype = etype
        self.cid = cid

    def __repr__(self):
        text = self.text.replace('\n',' ') + '|' if self.text else ''
        i,j = self.abs_char_start, self.abs_char_end
        sep = '...' if len(self.span) > 1 else '-'
        return f"Annotation[{self.etype}]({text}{i}{sep}{j})"

    @property
    def type(self):
        """Alias for ``etype``."""
        return self.etype

    def _key(self):
        """Identity tuple shared by ``__hash__`` and ``__eq__``."""
        return (self.etype, self.doc_name, self.span)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        """Equal when type, document and span all match.

        Previously any two Annotation instances compared equal, which
        disagreed with ``__hash__`` and could merge distinct annotations on
        a hash collision.
        """
        if not isinstance(other, type(self)):
            return False
        return self._key() == other._key()

In [4]:
import re
import os
import gzip
import json
import glob
import numpy as np
import logging
import itertools
import collections

logger = logging.getLogger(__name__)


def parse_doc(d) -> Document:
    """
    Build a Document (with Sentence children) from one parsed JSON record.

    Reads ``d['name']`` and ``d['sentences']``; when a ``'metadata'`` mapping
    is present its entries are copied into ``doc.props``. Transforming to
    Document/Sentence objects comes at ~13% overhead.
    """
    sentences = [Sentence(**fields) for fields in d['sentences']]
    doc = Document(d['name'], sentences)
    if 'metadata' in d:
        doc.props.update(d['metadata'].items())
    return doc


class DocumentLoader:
    """Lazily yields parsed documents from a JSON-lines file or directory.

    A directory is expanded to its ``*.json`` files; any other path is read
    as a single file. Files whose last extension is ``gz`` are opened with
    gzip. Each line is JSON-decoded and passed through ``self.formatter``.
    """

    def __init__(self, fpath):
        self.fpath = fpath
        self.formatter = parse_doc

    def filelist(self):
        """Paths to read: the directory's ``*.json`` files, or ``[fpath]``."""
        if os.path.isdir(self.fpath):
            return glob.glob(f'{self.fpath}/*.json')
        return [self.fpath]

    def __iter__(self):
        for path in self.filelist():
            is_gzip = path.rsplit(".", 1)[-1] == 'gz'
            opener = gzip.open if is_gzip else open
            with opener(path, 'rb') as handle:
                for raw_line in handle:
                    record = json.loads(raw_line)
                    yield self.formatter(record)


def load_json_dataset(fpath,
                      tokenizer,
                      tag_fmt = 'IO',
                      contiguous_only = False):
    """Load JSON dataset and initialize sequence tagged labels.

    Each line of the file is one JSON document with ``name`` and
    ``sentences`` keys and an optional ``entities`` list.

    Parameters
    ----------
    fpath
        JSON-lines file path; a last extension of ``gz`` is read with gzip
    tokenizer
        passed through unchanged to ``NerDocumentDataset``
    tag_fmt
        token tagging scheme with values in {'IO','IOB', 'IOBES'}
    contiguous_only
        when True, entities with more than one span segment are dropped

    Returns
    -------
    NerDocumentDataset built from the loaded documents and entities
    """
    documents, entities = [], {}
    fopen = gzip.open if fpath.split(".")[-1] == 'gz' else open
    with fopen(fpath, 'rb') as fp:
        for line in fp:
            # initialize context objects
            d = json.loads(line)
            doc = Document(d['name'], [Sentence(**s) for s in d['sentences']])
            documents.append(doc)
            # load entities
            doc_entities = entities[doc.name] = set()
            for entity in d.get('entities', ()):
                # Annotation derives these from `span`; drop them if present
                # (records without the precomputed keys no longer raise KeyError)
                entity.pop('abs_char_start', None)
                entity.pop('abs_char_end', None)
                entity.setdefault('doc_name', doc.name)
                anno = Annotation(**entity)
                if contiguous_only and len(anno.span) > 1:
                    continue
                # reuse the instance instead of constructing it a second time
                doc_entities.add(anno)

    return NerDocumentDataset(documents,
                              entities,
                              tag_fmt=tag_fmt,
                              tokenizer=tokenizer)


#################################################################################
#
#  Sequence Tag Creation
#
#################################################################################

def entity_tag(length, tag_fmt="IOB"):
    """
    Build the tag-prefix sequence for one entity of ``length`` tokens using
    IO, IOB, or IOBES (equiv. to BILOU) tagging.

    The scheme is matched on the SET of characters in ``tag_fmt``, so letter
    order does not matter ('BIO' == 'IOB'). 'BILOU' selects the same
    B/I/E/S output letters as 'IOBES'.

    :param length: number of tokens in the entity
    :param tag_fmt: tagging scheme name
    :return: list of ``length`` single-letter tags; all 'O' for an
        unrecognised scheme; empty list when ``length`` <= 0 (previously an
        IndexError for IOB/IOBES)
    """
    if length <= 0:
        return []

    scheme = set(tag_fmt)

    if scheme == set("IOB"):
        return ['B'] + ['I'] * (length - 1)

    if scheme == set("IOBES") or scheme == set("BILOU"):
        if length == 1:
            return ['S']
        return ['B'] + ['I'] * (length - 2) + ['E']

    if scheme == set("IO"):
        return ['I'] * length

    # unknown scheme: leave everything untagged
    return ['O'] * length


def map_sent_entities(document, entities, verbose=True):
    """
    Given (1) a document split into sentences and (2) a list of entities
    defined by absolute char offsets, map each entity to its parent sentence.

    An entity belongs to sentence ``i`` when it starts at or after that
    sentence's first offset and ends at or before the next sentence's first
    offset; the last sentence only requires the start condition. Entities
    that fit no sentence are counted as errors (cross-sentence mentions).

    :param document: object exposing ``name`` and ``sentences``
    :param entities: iterable of entities with ``abs_char_start`` /
        ``abs_char_end`` (plus ``text``, ``span``, ``doc_name`` for logging)
    :param verbose: log a warning for every unmapped entity
    :return: tuple of (defaultdict: sentence index -> list of entities,
        number of unmapped entities)
    """
    errors = 0
    idx = collections.defaultdict(list)
    sentences = document.sentences
    char_index = [s.abs_char_offsets[0] for s in sentences]
    last = len(char_index) - 1

    for t in entities:
        position = None
        for i in range(last):
            if t.abs_char_start >= char_index[i] and t.abs_char_end <= char_index[i + 1]:
                position = i
                break

        # the final sentence has no successor to bound the entity's end;
        # guarding on `char_index` avoids an IndexError for empty documents
        if position is None and char_index and t.abs_char_start >= char_index[-1]:
            position = last

        if position is None:
            if verbose:
                msg = f"{[t.text]} {t.span} {t.doc_name}"
                logger.warning(f"Cross-sentence mention {msg}")
            errors += 1
            continue

        try:
            # slice the sentence text as a sanity check; an entity whose
            # sentence text cannot be read is logged and left unmapped
            sentence = sentences[position]
            shift = sentence.abs_char_offsets[0]
            sentence.text[t.abs_char_start - shift:t.abs_char_end - shift]
        except Exception as e:
            logger.error(f'{e}')
            continue
        idx[position].append(t)

    return idx, errors


def retokenize(sent, tokenizer, subword='##'):
    """
    Re-tokenize every word of ``sent`` with ``tokenizer.tokenize`` and derive
    an absolute character offset for each resulting piece (e.g., BPE).
    By convention, wordpiece tokens are prefixed by ##; the prefix is not
    counted when advancing the offset.

    Each word contributes its own start offset plus one offset per
    non-final piece, so a word for which the tokenizer returns no pieces
    still contributes one offset.

    :return: (tokens, abs_char_offsets)
    """
    prefix_len = len(subword)
    tokens, abs_char_offsets = [], []

    for idx, word in enumerate(sent.words):
        pieces = tokenizer.tokenize(word)
        cursor = sent.abs_char_offsets[idx]
        abs_char_offsets.append(cursor)
        # every piece except the last pushes the start of its successor
        for piece in pieces[0:-1]:
            if piece[:prefix_len] == subword:
                cursor += len(piece[prefix_len:])
            else:
                cursor += len(piece)
            abs_char_offsets.append(cursor)
        tokens.extend(pieces)

    return tokens, abs_char_offsets


def tokens_to_tags(sent,
                   entities,
                   tag_fmt='BIO',
                   tokenizer=None,
                   max_seq_len=512):
    """
    Assign a tag to every token of a sentence from its entity annotations.

    :param sent: sentence exposing ``words``, ``abs_char_offsets`` and
        ``document.name``
    :param entities: entities located in this sentence; each needs ``span``,
        ``type`` and ``text``
    :param tag_fmt: tagging scheme forwarded to ``entity_tag``
    :param tokenizer: optional tokenizer; when given the sentence is
        re-tokenized with ``retokenize``, otherwise the original words are used
    :param max_seq_len: tokens beyond ``max_seq_len - 2`` are truncated
        (presumably leaving room for two special tokens — TODO confirm)
    :return: ``((toks, tags, is_heads), errs)`` where ``is_heads`` flags
        tokens that start an original word and ``errs`` counts entities that
        could not be aligned or whose text did not match
    """

    toks, abs_char_offsets = retokenize(sent, tokenizer) if tokenizer \
        else (sent.words, sent.abs_char_offsets)

    # truncate long sequences
    if len(toks) > max_seq_len - 2:
        toks = toks[0:max_seq_len - 2]
        abs_char_offsets = abs_char_offsets[0:max_seq_len - 2]

    # use original tokenization to assign token heads
    is_heads = [1 if i in sent.abs_char_offsets else 0 for i in abs_char_offsets]
    tags = ['O'] * len(toks)

    errs = 0
    for entity in entities:

        # currently we only support contiguous entity spans
        if len(entity.span) != 1:
            logger.warning(f"Non-contiguous entities not supported {entity} {sent.document.name}")
            continue

        head = entity.span[0]
        if head[0] in abs_char_offsets:
            start = abs_char_offsets.index(head[0])
            end = len(abs_char_offsets)

            # end = index of the first token whose offset is not before the
            # entity's end offset; stays at len(...) if no such token exists
            for j, offset in enumerate(abs_char_offsets):
                if head[-1] > offset:
                    continue
                end = j
                break

            # tokenization error
            if is_heads[start] == 0:
                errs += 1
                logger.warning(f"Tokenization Error: Token is not a head token {entity} {sent.document.name}")
                continue

            # one scheme tag per head token inside the entity
            tok_len = is_heads[start:end].count(1)
            head_tags = entity_tag(tok_len, tag_fmt=tag_fmt)
            head_tags = [f'{t}-{entity.type}' for t in head_tags]
            io_tags = ['O'] * len(toks[start:end])

            # head tokens consume the next tag; non-head (sub-word) tokens
            # repeat the tag of the head token before them
            for i in range(len(io_tags)):
                if is_heads[start:end][i] == 1:
                    t = head_tags.pop(0)
                io_tags[i] = t

            tags[start:end] = io_tags

            # Error Checking: do spans match?
            # compare the joined tokens ('##' prefix stripped) with the
            # whitespace-free entity text, both lower-cased
            # NOTE(review): re.sub raises TypeError when entity.text is None — confirm callers always set text
            s1 = ''.join([w if w[:2] != '##' else w[2:] for w in toks[start:end]]).lower()
            s2 = re.sub(r'''(\s)+''', '', entity.text).lower()


            if s1 != s2:
                if len(entity.span) == 1:
                    msg = f"{s1} != {s2}"
                    logger.error(f"Span does not match {msg}")
                errs += 1
        else:
            errs += 1
            logger.error(f"Tokenization Error: Token head not found in abs_char_offsets {entity}")

    return (toks, tags, is_heads), errs





#################################################################################
#
#  Datasets
#
#################################################################################


class NerDocumentDataset(object):
    """
    Document + Annotation objects
    entities are defined as abs char offsets per document

    On construction every sentence of every document is converted into a
    ``(tokens, tags, is_heads)`` triple stored in ``self.data``, index-aligned
    with ``self.sentences``.
    """

    def __init__(self, documents: dict,
                 entities: dict,
                 tag_fmt: str = 'IO',
                 tokenizer=None) -> None:
        """
        Convert Document objects with a corresponding
        entity set into tagged sequences

        :param documents: iterated directly and each item is read via
            ``.name`` / ``.sentences``
            (NOTE(review): annotated as dict, but used like a sequence of Document objects — confirm)
        :param entities: mapping of document name -> collection of annotations
        :param tag_fmt: tagging scheme; its letters other than 'O' become tag prefixes
        :param tokenizer: optional tokenizer forwarded to ``tokens_to_tags``;
            ``__getitem__`` calls ``convert_tokens_to_ids`` on it
        """
        self.documents = documents
        self.entities = entities
        self.tag_fmt = tag_fmt
        self.tokenizer = tokenizer
        self.tag2idx = self._get_tag_index(entities, tag_fmt)

        self._init_sequences(documents)

    def _get_tag_index(self, entities, tag_fmt):
        """
        Given a collection of entity types, initialize an integer tag mapping
        e.g., B-Drug I-Drug O

        Index 0 is 'X' and index 1 is 'O'; the remaining ids follow the
        product of tag letters and entity types. Entity types come from a
        set, so the order of those remaining ids depends on set iteration order.

        :param entities: mapping of document name -> annotations (each with ``.type``)
        :param tag_fmt: tagging scheme string
        :return: dict mapping tag string -> integer id
        """
        entity_types = {t.type for doc_name in entities for t in entities[doc_name]}
        tags = [t for t in list(tag_fmt) if t != 'O']
        tags = [f'{tag}-{etype}' for tag, etype in itertools.product(tags, entity_types)]
        tags = ['X', 'O', ] + tags
        return {t: i for i, t in enumerate(tags)}

    def __len__(self) -> int:
        """Number of tagged sentences."""
        return len(self.data)

    def tagged(self, idx):
        """
        Return tagged words
        :return: the words and string tags of sentence ``idx`` with the
            first and last (special-token) positions stripped
        """
        # positions 0 and 3 of __getitem__'s tuple are `words` and `tags`,
        # so X holds words and Y holds string tags here
        X, _, _, Y, _, _ = self.__getitem__(idx)
        return X[1:-1], Y[1:-1]

    def _init_sequences(self, documents):
        """
        Transform Documents into labeled sequences.

        Fills ``self.data`` and ``self.sentences`` and prints/logs alignment
        statistics.

        :param documents: documents to convert
        :return: None
        """
        self.data = []
        self.sentences = []
        num_errors, num_missing_heads, num_entities = 0, 0, 0

        for doc in documents:
            self.sentences.extend(doc.sentences)
            annotations = self.entities[doc.name] if doc.name in self.entities else {}
            num_entities += len(annotations)
            # tag sentences
            sent_entities, errs = map_sent_entities(doc, annotations)
            num_errors += errs

            for sentence in doc.sentences:
                # entities are looked up by the sentence's `i` attribute
                entities = sent_entities[sentence.i] if sentence.i in sent_entities else []
                seqs, errs = tokens_to_tags(sentence, entities, self.tag_fmt, tokenizer=self.tokenizer)
                num_errors += errs

                x, y, is_heads = seqs
                # debugging aid: dump the triple when its lengths disagree
                if not (len(x) == len(y) == len(is_heads)):
                    print(seqs)

                self.data.append(seqs)

        assert len(self.data) == len(self.sentences)
        if num_errors:
            msg = f'Errors: Span Alignment: {num_errors}/{num_entities} ({num_errors / num_entities * 100:2.1f}%)'
            logger.warning(msg)

        print(f'Tagged Entities: {num_entities - num_errors}')

    def __getitem__(self, idx):
        """Return ``(words, X, is_heads, tags, Y, len(Y))`` for sentence ``idx``.

        ``words``/``tags`` are wrapped in '[CLS]'/'[SEP]' and 'X' markers,
        ``X`` holds token ids from ``self.tokenizer`` and ``Y`` holds tag ids.
        """

        toks, tags, is_heads = self.data[idx]

        # NOTE(review): the filtered list below is overwritten on the next line and never used
        words = [w for w in self.sentences[idx].words if w.strip()]
        words = self.sentences[idx].words
        # original tags (head words only)
        tags = [t for i, t in enumerate(tags) if is_heads[i] == 1]

        # debugging aid: report sentences whose word/tag counts disagree
        if len(words) != len(tags):
            print(len(words), len(tags))
            print(words)
            print(tags)
            print('-' * 50)

        words = ['[CLS]'] + words + ['[SEP]']
        toks = ['[CLS]'] + toks + ['[SEP]']
        tags = ['X'] + tags + ['X']

        X = self.tokenizer.convert_tokens_to_ids(toks)
        # NOTE(review): `tags` is head-only and padded with 'X', while `is_heads`
        # is token-level and unpadded, so this zip pairs different positions and
        # truncates to the shorter list — confirm the intended alignment
        Y = [self.tag2idx[t] if h == 1 else self.tag2idx['X'] for t, h in zip(tags, is_heads)]

        return words, X, is_heads, tags, Y, len(Y)

In [5]:
import numpy as np


def mv(L, break_ties, abstain=-1):
    """Simple majority vote.

    Parameters
    ----------
    L
        2-D label matrix, one row per data point, one column per voter
    break_ties
        label returned when a row has no non-abstain votes or when two or
        more labels share the highest count
    abstain
        value marking an abstaining vote

    Returns
    -------
    1-D integer array with one predicted label per row.

    Notes
    -----
    Ties are detected explicitly. ``statistics.mode`` stopped raising on
    multimodal data in Python 3.8, so the earlier try/except only applied
    ``break_ties`` to empty rows. ``np.int`` was removed from NumPy and is
    replaced by the builtin ``int``.
    """
    y_hat = []
    for row in np.asarray(L):
        # get non abstain votes
        votes = row[row != abstain]
        if votes.size == 0:
            y_hat.append(break_ties)
            continue
        labels, counts = np.unique(votes, return_counts=True)
        winners = labels[counts == counts.max()]
        y_hat.append(winners[0] if winners.size == 1 else break_ties)
    return np.array(y_hat).astype(int)

def smv(L, abstain=-1, uncovered=0):
    """Soft majority vote.

    Classes are every integer between the smallest and largest non-abstain
    label found in ``L``. Each row becomes the fraction of its non-abstain
    votes per class.

    Parameters
    ----------
    L
        2-D label matrix, one row per data point
    abstain
        value marking an abstaining vote
    uncovered
        class that receives probability 1.0 for rows without any vote; if
        that class is not among the observed classes the first class is used

    Returns
    -------
    Float array of shape (n_rows, n_classes).

    Raises
    ------
    ValueError
        if ``L`` contains no non-abstain vote at all.

    Notes
    -----
    The uncovered row used to be the fixed two-element ``[1.0, 0]``, which
    produced a ragged result for any other class count. ``np.float`` was
    removed from NumPy and is replaced by the builtin ``float``.
    """
    L = np.asarray(L)
    observed = np.unique(L[L != abstain]).astype(int)
    if observed.size == 0:
        raise ValueError("smv() needs at least one non-abstain vote in L")
    k = list(range(min(observed), max(observed) + 1))

    # distribution used for rows in which every voter abstained
    fallback = [0.0] * len(k)
    fallback[k.index(uncovered) if uncovered in k else 0] = 1.0

    y_hat = []
    for row in L:
        # get non abstain votes
        votes = list(row[row != abstain])
        N = len(votes)
        if not N:
            y_hat.append(list(fallback))
        else:
            y_hat.append([votes.count(i) / N for i in k])
    return np.array(y_hat).astype(float)

In [6]:
import itertools
import numpy as np
from typing import List, Set, Dict, Tuple, Pattern, Match, Iterable
import seqeval.metrics


def split_by_seq_len(X, X_lens) -> List[np.ndarray]:
    """Given a matrix X of M elements, partition into N variable length
    sequences where [x1, ..., xN] lengths are defined by X_lens[i].

    This is used to partition a stacked matrix of words back into sentences.

    Parameters
    ----------
    X
        array-like stacked along its first axis
    X_lens
        length of each sequence, in order (array or plain list)

    Returns
    -------
    list of N sub-arrays of X (a single-element list when X_lens is empty)

    """
    # split points are the running totals of all lengths but the last;
    # one cumulative sum replaces re-summing every prefix
    splits = np.cumsum(np.asarray(X_lens, dtype=int))[:-1]
    return np.split(X, splits)


def convert_tag_fmt(
        seq: List[str],
        etype: str,
        tag_fmt: str = 'IOB') -> List[str]:
    """Re-tag an IO sequence in another tagging scheme.

    Every maximal run of 'I' tags is treated as one entity and re-tagged
    with ``entity_tag``; 'O' runs are kept. Non-'O' tags get ``-{etype}``
    appended. Going the other way ({IOB, IOBES} -> IO) would be lossy since
    it drops the boundary between adjacent entities, e.g.

    IOB -> O B I I B I O
    IO  -> O I I I I I O
    IOB -> O B I I I I O

    Parameters
    ----------
    seq
        tags drawn from {'I', 'O'} only (asserted)
    etype
        entity type suffix added to non-'O' tags
    tag_fmt
        target tagging scheme

    Returns
    -------
    list of tags in the target scheme

    """
    # TODO: Only works for IO -> {IOB, IOBES}
    assert set(seq).issubset(set('IO'))

    converted = []
    # walk the sequence one contiguous run of identical tags at a time
    for _, group in itertools.groupby(seq):
        run = list(group)
        if 'O' in run:
            converted.extend(run)
        else:
            converted.extend(entity_tag(len(run), tag_fmt))

    return [tag if tag == 'O' else f'{tag}-{etype}' for tag in converted]


def tokens_to_sequences(y_gold,
                        y_pred,
                        seq_lens,
                        idx2tag=None,
                        tag_fmt=None):
    """Convert token labels to sentences for sequence model evaluation.

    Parameters
    ----------
    y_gold
        stacked gold label ids; every id must be a key of ``idx2tag``
    y_pred
        stacked predicted label ids; ids missing from ``idx2tag`` become 'O'
    seq_lens
        length of each sentence, used to split the stacked labels
    idx2tag
        label id -> tag string; defaults to {1: 'I', 0: 'O'}
    tag_fmt
        when given, every sequence is converted from IO to this tagging
        scheme with entity type 'ENTITY'. The value is now forwarded to
        ``convert_tag_fmt``; it used to be ignored in favour of a
        hard-coded 'IOB'.

    Returns
    -------
    (y_gold_seqs, y_pred_seqs): two lists of per-sentence tag lists

    """
    idx2tag = {1: 'I', 0: 'O'} if not idx2tag else idx2tag

    def _to_tag_seqs(labels, unknown_as_outside):
        # shared gold/pred path: split into sentences, map ids to tags,
        # optionally convert the tagging scheme
        seqs = []
        for s in split_by_seq_len(labels, seq_lens):
            if unknown_as_outside:
                y = [idx2tag[i] if i in idx2tag else 'O' for i in s]
            else:
                y = [idx2tag[i] for i in s]
            if tag_fmt is not None:
                y = convert_tag_fmt(y, etype='ENTITY', tag_fmt=tag_fmt)
            seqs.append(y)
        return seqs

    y_gold_seqs = _to_tag_seqs(y_gold, unknown_as_outside=False)
    # Sometimes -1 labels make it into evaluation due to Snorkel
    # label model. Just treat these as 'O'
    y_pred_seqs = _to_tag_seqs(y_pred, unknown_as_outside=True)

    return y_gold_seqs, y_pred_seqs


def score_sequences(y_true: List[List[str]],
                    y_pred: List[List[str]],
                    metrics: Set[str] = None) -> Dict[str, float]:
    """
    Sequence model evaluation using seqeval
    https://github.com/chakki-works/seqeval

    Parameters
    ----------
    y_true
        gold tag sequences, one list of tag strings per sentence
    y_pred
        predicted tag sequences, same layout as ``y_true``
    metrics
        names to compute, drawn from {'accuracy', 'precision', 'recall',
        'f1'}; all four when None

    Returns
    -------
    dict of metric name -> score. If scoring fails, the error is logged and
    every requested metric is reported as 0.0.

    Notes
    -----
    The scorers come from ``seqeval.metrics``, which this cell imports and
    which accepts lists of tag sequences. The earlier ``sklearn.metrics``
    lookups referenced a name that is never bound in this notebook, and
    ``accuracy_score`` was called with an ``average`` argument it does not
    accept, so the bare ``except`` reported 0.0 for every metric.
    """
    import logging

    # (scorer, extra keyword arguments); accuracy takes no `average`
    scorers = {
        'accuracy': (seqeval.metrics.accuracy_score, {}),
        'precision': (seqeval.metrics.precision_score, {'average': 'macro'}),
        'recall': (seqeval.metrics.recall_score, {'average': 'macro'}),
        'f1': (seqeval.metrics.f1_score, {'average': 'macro'})
    }
    metrics = metrics if metrics is not None else scorers
    try:
        scores = {}
        for name in metrics:
            scorer, kwargs = scorers[name]
            scores[name] = scorer(y_true, y_pred, **kwargs)
        return scores
    except Exception as exc:
        # keep the best-effort fallback, but make the failure visible
        logging.getLogger(__name__).warning(
            "score_sequences failed (%s: %s); reporting 0.0 for all metrics",
            type(exc).__name__, exc)
        return {name: 0.0 for name in metrics}


def eval_label_model(model, L, Y, seq_lens):
    """Print sequence-level scores of ``model`` and of a majority-vote
    baseline (ties and empty rows resolved to label 0) on the same data.

    ``L`` is the label matrix, ``Y`` the gold labels and ``seq_lens`` the
    per-sentence lengths used to split both back into sequences.
    """
    idx2tag = {0: 'O', 1: 'I-X', 2: 'B-X'}

    def _report(template, y_pred):
        # score the predictions against Y and print them as percentages
        gold_seqs, pred_seqs = tokens_to_sequences(Y, y_pred, seq_lens, idx2tag=idx2tag)
        scores = score_sequences(gold_seqs, pred_seqs)
        formatted = [f'{m}: {v * 100:2.2f}' for m, v in scores.items()]
        print(template.format(' | '.join(formatted)))

    # label model
    _report('[Label Model]   {}', model.predict(L))

    # MV baseline
    _report('[Majority Vote] {}', mv(L, 0))

In [7]:
import numpy as np
from itertools import product
from sklearn.metrics import (
    precision_score, recall_score,
    f1_score, accuracy_score,
    precision_recall_fscore_support
)




def sample_param_grid(param_grid, seed):
    """ Enumerate every parameter combination in a seeded random order.

    The global NumPy random state is saved before seeding and restored
    afterwards, so callers' random streams are not disturbed.

    :param param_grid: mapping of parameter name -> list of candidate values
    :param seed: seed that determines the shuffle order
    :return: list of value tuples, ordered like the keys of ``param_grid``
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    combos = list(product(*(param_grid[key] for key in param_grid)))
    np.random.shuffle(combos)
    np.random.set_state(saved_state)
    return combos


def compute_metrics(y_gold, y_pred, average='binary'):
    """
    Compute accuracy, precision, recall and F1 for a set of predictions.

    :param y_gold: gold labels
    :param y_pred: predicted labels
    :param average: averaging mode passed to precision/recall/F1
        (accuracy takes no averaging argument)
    :return: dict with keys 'accuracy', 'precision', 'recall', 'f1'
    """
    scores = {'accuracy': accuracy_score(y_gold, y_pred)}
    averaged = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    for name, scorer in averaged:
        scores[name] = scorer(y_gold, y_pred, average=average)
    return scores


def grid_search_span(model_class,
                     model_class_init,
                     param_grid,
                     train=None,
                     dev=None,
                     n_model_search=5,
                     val_metric='f1',
                     seed=1234,
                     verbose=True):
    """Simple grid search helper function

    Samples up to ``n_model_search`` configurations from ``param_grid``,
    fits one label model per configuration and keeps the one with the best
    ``val_metric`` on the covered part of the dev set.

    Parameters
    ----------
    model_class
        label model class; instances need ``fit`` and ``predict``
    model_class_init
        keyword arguments for ``model_class``
    param_grid
        mapping of fit-parameter name -> list of candidate values
    train
        ``(L_train, Y_train)`` or a 1-tuple ``(L_train,)`` without gold labels
    dev
        ``(L_dev, Y_dev)``
    n_model_search
        maximum number of sampled configurations to try
    val_metric
        key of ``compute_metrics`` output used for model selection
    seed
        seed for sampling; also the default ``seed`` fit parameter
    verbose
        print coverage / score summaries

    Returns
    -------
    (model, best_config): the model re-fit with the best configuration

    Raises
    ------
    RuntimeError
        if every configuration was skipped because the model predicted -1.
    """
    L_train, Y_train = train if len(train) == 2 else (train[0], None)
    L_dev, Y_dev = dev

    # sample configs
    params = sample_param_grid(param_grid, seed)[:n_model_search]

    defaults = {'seed': seed}
    best_score, best_config = None, None
    # set scoring mode based on the number of classes
    average = 'binary' if np.unique(Y_dev).shape[0] == 2 else 'micro'

    print(f"Grid search over {len(params)} configs")
    print(f'Averaging: {average}')

    # dev rows with at least one non-abstain vote (same for every config)
    dev_mask = np.array(
        [r for r in range(L_dev.shape[0]) if not np.all(L_dev[r] == -1)],
        dtype=int)

    for i, config in enumerate(params):
        print(f'[{i}] Label Model')
        config = dict(zip(param_grid.keys(), config))
        # update default params if not specified
        for param, value in defaults.items():
            config.setdefault(param, value)

        model = model_class(**model_class_init)
        # fit (estimate class balance with Y_dev)
        model.fit(L_train, Y_dev, **config)

        y_pred = model.predict(L_dev)

        # Snorkel sometimes emits -1 predictions
        if -1 in y_pred:
            continue

        # only evaluate dev score on covered rows; `average` is passed on so
        # multi-class gold labels are not scored in binary mode
        metrics = compute_metrics(Y_dev[dev_mask],
                                  model.predict(L_dev[dev_mask]),
                                  average=average)

        msgs = []
        if best_score is None or metrics[val_metric] > best_score[val_metric]:
            print(config)
            best_score = metrics
            best_config = config

            # mask uncovered data points
            covered = [r for r in range(L_train.shape[0])
                       if not np.all(L_train[r] == -1)]
            msgs.append(
                f'Coverage: {(len(covered) / L_train.shape[0] * 100):2.1f}%'
            )

            if Y_train is not None:
                # filter out candidate spans without gold labels
                labeled = [r for r in range(len(Y_train)) if Y_train[r] != -1]
                train_mask = np.array(sorted(set(labeled).intersection(covered)),
                                      dtype=int)
                train_metrics = compute_metrics(Y_train[train_mask],
                                                model.predict(L_train[train_mask]),
                                                average=average)
                msgs.append(
                    'TRAIN {}'.format(' | '.join(
                        [f'{m}: {v * 100:2.2f}' for m, v in train_metrics.items()])
                    )
                )

            msgs.append(
                'DEV   {}'.format(' | '.join(
                    [f'{m}: {v * 100:2.2f}' for m, v in best_score.items()]))
            )

        if verbose and msgs:
            print('\n'.join(msgs) + ('\n' + '-' * 80))

    if best_config is None:
        raise RuntimeError(
            "grid_search_span: no configuration produced usable predictions "
            "(all were skipped or param_grid was empty)")

    # retrain best model
    if verbose:
        print('BEST')
        print(best_config)
    model = model_class(**model_class_init)
    model.fit(L_train, Y_dev, **best_config)
    return model, best_config


def grid_search(model_class,
                model_class_init,
                param_grid,
                train=None,
                dev=None,
                other_train=None,
                n_model_search=5,
                val_metric='f1',
                seed=1234,
                seq_eval=True,
                checkpoint_gt_mv=True,
                tag_fmt_ckpnt='BIO'):
    """Simple grid search helper function

    Samples up to ``n_model_search`` configurations, fits one label model
    per configuration and keeps the one with the best dev ``val_metric``.
    With ``checkpoint_gt_mv`` a configuration is only eligible when it
    scores at least as well as the majority-vote baseline.

    Parameters
    ----------
    model_class
        label model class; instances need ``fit``, ``predict`` and (for
        token-level evaluation) ``score``
    model_class_init
        keyword arguments for ``model_class``
    param_grid
        mapping of fit-parameter name -> list of candidate values
    train
        ``(L_train, Y_train, X_train_lens)``
    dev
        ``(L_dev, Y_dev, X_dev_lens)``
    other_train
        unused
    n_model_search
        maximum number of sampled configurations to try
    val_metric
        metric name used for model selection
    seed
        seed for sampling; also the default ``seed`` fit parameter
    seq_eval
        score whole tag sequences (True) or individual tokens (False)
    checkpoint_gt_mv
        skip configurations that score below majority vote
    tag_fmt_ckpnt
        'IO' collapses gold labels to {0, 1} before scoring

    Returns
    -------
    (model, best_config): the model re-fit with the best configuration

    Raises
    ------
    RuntimeError
        if no configuration was eligible.
    """
    print(f"Using {'TOKEN' if not seq_eval else 'SEQUENCE'} dev checkpointing")
    if seq_eval:
        print(f"Using {tag_fmt_ckpnt} dev checkpointing")

    idx2tag = {0:'O', 1:'I-X', 2:'B-X'}
    score_names = ['accuracy', 'precision', 'recall', 'f1']

    L_train, Y_train, X_train_lens = train
    L_dev, Y_dev, X_dev_lens = dev

    # sample configs
    params = sample_param_grid(param_grid, seed)[:n_model_search]

    defaults = {'seed': seed}
    best_score, best_config = None, None
    print(f"Grid search over {len(params)} configs")

    # HACK for BIO tag evaluation: the labels handed to fit() are collapsed
    # to binary unless they already have exactly two classes
    if len(np.unique(Y_dev)) != 2:
        Y_dev_hat = np.array([0 if y == 0 else 1 for y in Y_dev])
    else:
        Y_dev_hat = Y_dev

    # set gold tags for evaluation
    if tag_fmt_ckpnt == 'IO':
        y_gold_dev = np.array([0 if y == 0 else 1 for y in Y_dev])
    else:
        y_gold_dev = Y_dev

    # majority-vote baseline on the same labeled dev set; it does not depend
    # on the configuration, so it is computed once
    mv_metrics = None
    if checkpoint_gt_mv:
        mv_y_pred = mv(L_dev, 0)
        if seq_eval:
            mv_metrics = score_sequences(
                *tokens_to_sequences(y_gold_dev, mv_y_pred, X_dev_lens, idx2tag=idx2tag)
            )
        else:
            average = 'binary' if np.unique(y_gold_dev).shape[0] == 2 else 'micro'
            mv_metrics = compute_metrics(y_gold_dev, mv_y_pred, average=average)

    for i, config in enumerate(params):
        print(f'[{i}] Label Model')
        config = dict(zip(param_grid.keys(), config))
        # update default params if not specified
        config.update({param: value for param, value in defaults.items() if param not in config})

        model = model_class(**model_class_init)
        # fit (estimate class balance with Y_dev)
        model.fit(L_train, Y_dev_hat, **config)

        y_pred = model.predict(L_dev)

        if -1 in y_pred:
            print("Label model predicted -1 (TODO: this happens inconsistently)")
            continue

        # score on dev set (token or sequence-level)
        if seq_eval:
            metrics = score_sequences(*tokens_to_sequences(y_gold_dev, y_pred, X_dev_lens, idx2tag=idx2tag))
        else:
            # use internal label model scorer
            metrics = model.score(L=L_dev,
                                  Y=y_gold_dev,
                                  metrics=score_names,
                                  tie_break_policy=0)

        # compare learned model against MV on same labeled dev set
        # skip if LM less than MV (the old check compared the model's score
        # with itself and never skipped anything)
        if mv_metrics is not None and metrics[val_metric] < mv_metrics[val_metric]:
            continue

        if best_score is None or metrics[val_metric] > best_score[val_metric]:
            print(config)
            best_score = metrics
            best_config = config

            # print training set score if we have labeled data
            if np.any(Y_train):
                y_pred_train = model.predict(L_train)

                if tag_fmt_ckpnt == 'IO':
                    y_gold_train = np.array([0 if y == 0 else 1 for y in Y_train])
                else:
                    y_gold_train = Y_train

                if seq_eval:
                    train_metrics = score_sequences(
                        *tokens_to_sequences(y_gold_train, y_pred_train, X_train_lens, idx2tag=idx2tag))
                else:
                    train_metrics = model.score(L=L_train,
                                                Y=y_gold_train,
                                                metrics=score_names,
                                                tie_break_policy=0)

                print('[TRAIN] {}'.format(' | '.join([f'{m}: {v * 100:2.2f}' for m, v in train_metrics.items()])))

            print('[DEV]   {}'.format(' | '.join([f'{m}: {v * 100:2.2f}' for m, v in best_score.items()])))
            print('-' * 88)

    if best_config is None:
        raise RuntimeError(
            "grid_search: no configuration was eligible (all predicted -1, "
            "scored below majority vote, or param_grid was empty)")

    # retrain best model
    print('BEST')
    print(best_config)
    model = model_class(**model_class_init)
    model.fit(L_train, Y_dev_hat, **best_config)
    return model, best_config

In [8]:
def list2Nested(l, nested_length):
    """Split `l` into consecutive slices of at most `nested_length` items.

    The last slice is shorter when ``len(l)`` is not a multiple of
    `nested_length`; an empty `l` yields an empty list.
    """
    chunks = []
    for start in range(0, len(l), nested_length):
        chunks.append(l[start:start + nested_length])
    return chunks

In [9]:
# Fetch UMLS ranks
#
# Tab-separated labeling-function (LF) summary files, one per label column
# ('p', 'i', 'o', going by the file names). Each one is parsed by fetchRank().
# NOTE(review): absolute paths on a specific mount -- the notebook cannot be
# re-run on another machine without editing these.

sum_lf_p = '/mnt/nas2/results/Results/systematicReview/distant_pico/EBM_PICO_GT/lf_p_summary_train.csv'
sum_lf_i = '/mnt/nas2/results/Results/systematicReview/distant_pico/EBM_PICO_GT/lf_i_summary_train.csv'
sum_lf_o = '/mnt/nas2/results/Results/systematicReview/distant_pico/EBM_PICO_GT/lf_o_summary_train.csv'


def fetchRank(sum_lf_d):
    """Rank the UMLS labeling functions listed in one summary file.

    Parameters
    ----------
    sum_lf_d : str or path-like
        Tab-separated summary file with a header row. Column 0 is read as the
        labeling-function name and column 3 as the value to rank by (called
        "coverage" in this notebook -- presumably the Coverage column of an LF
        summary; TODO confirm against the file).

    Returns
    -------
    dict
        One entry per row whose name starts with ``'UMLS_fuzzy_'``: the last
        ``_``-separated token of the name mapped to its column-3 value, inserted
        in descending order of that value (ties keep file order). If two names
        share the same last token, the first keeps its rank position and the
        later value overwrites it.
    """
    data = pd.read_csv(sum_lf_d, sep='\t')

    # Address the columns by position with .iloc. The previous `row[0]` /
    # `row[3]` on a label-indexed row depended on the deprecated positional
    # fallback of Series.__getitem__.
    names = data.iloc[:, 0]
    coverage = data.iloc[:, 3]

    # Rows with a missing name are read as NaN (a float); skip them instead of
    # failing on `.startswith`.
    umls_coverage_ = {
        name: cov
        for name, cov in zip(names, coverage)
        if isinstance(name, str) and name.startswith('UMLS_fuzzy_')
    }

    # sorted() is stable, so equal values stay in file order.
    umls_coverage_sorted = sorted(umls_coverage_.items(), key=lambda x: x[1], reverse=True)

    ranked_umls_coverage = dict()
    for name, cov in umls_coverage_sorted:
        ranked_umls_coverage[name.split('_')[-1]] = cov

    return ranked_umls_coverage

# One ranking per summary file: {last name token -> column-3 value}, highest value first.
ranksorted_p_umls = fetchRank(sum_lf_p)
ranksorted_i_umls = fetchRank(sum_lf_i)
ranksorted_o_umls = fetchRank(sum_lf_o)

In [10]:
# Partition LF's

def partitionLFs(umls_d):
    """Enumerate the rank-ordered groupings of the keys of `umls_d`.

    For ``i = 0 .. len(keys) - 1`` the i-th partition keeps the first ``i``
    keys as singleton groups and puts all remaining keys into one final group:

        i = 0      -> [[k0, k1, ..., kn]]
        i = 1      -> [[k0], [k1, ..., kn]]
        ...
        i = n      -> [[k0], [k1], ..., [kn]]

    Parameters
    ----------
    umls_d : dict
        Only the keys and their iteration order are used.

    Returns
    -------
    list
        ``len(umls_d)`` partitions, each a list of key groups; ``[]`` for an
        empty dict.
    """
    keys = list(umls_d.keys())

    partitioned_lfs = []

    # The former special case for i == len(keys) could never run inside
    # range(0, len(keys)); i == 0 needs no special case either because
    # list2Nested([], 1) is [] and keys[0:] is the full key list.
    for i in range(len(keys)):
        groups = list2Nested(keys[:i], 1)  # leading keys, one group each
        groups.append(keys[i:])            # remaining keys share one group
        partitioned_lfs.append(groups)

    return partitioned_lfs


# Candidate groupings of the ranked UMLS keys, one list of partitions per ranking.
partitioned_p_umls = partitionLFs(ranksorted_p_umls)
partitioned_i_umls = partitionLFs(ranksorted_i_umls)
partitioned_o_umls = partitionLFs(ranksorted_o_umls)

In [11]:
# NOTE(review): LMutils is not referenced in this part of the notebook and the
# import sits outside the top import cell -- confirm it is still needed.
import LMutils

# Labeled table (tab-separated, header row). Later cells read its `tokens`
# column and the label column selected by `picos` in train().
# File variants seen so far: validation_labels, validation_labels_tui_pio2 (used here).
# NOTE(review): the name `file` is reused as a loop variable in a later cell.
file = '/mnt/nas2/results/Results/systematicReview/distant_pico/EBM_PICO_GT/validation_labels_tui_pio2.tsv'
df_data = pd.read_csv(file, sep='\t', header=0)

In [12]:
# Token column of the labeled table; its length is used later to truncate the
# LF label lists and as the seed length passed to getLFs() inside train().
Y_tokens = df_data['tokens']

In [13]:
# Same assignment as in the previous cell.
Y_tokens = df_data['tokens']
#Y_p = df_data['p']
#Y_i = df_data['i']
#Y_o = df_data['o']
# Unshuffled split: the first 80% of rows become train, the last 20% validation.
# train() applies the same unshuffled 80/20 split to the LF matrix, which keeps
# label rows and LF rows aligned only if both have the same number of rows.
df_data_train, df_data_val = train_test_split(df_data, test_size=0.20, shuffle=False)

In [14]:
# Token column of each split, listed in the same order as `splits`.
splits = ['train', 'dev']
X_sents = [
    df_data_train.tokens,
    df_data_val.tokens,
]

In [15]:
# Per split: an array with the character length of str(entry) for every entry
# of the tokens column.
# NOTE(review): the next cell assigns `X_seq_lens` again, so this value is
# discarded when the notebook is run top to bottom.
X_seq_lens = [
    np.array([len(str(s)) for s in X_sents[i]])
    for i,name in enumerate(splits)
]

In [16]:
# Per split: a one-element array holding the number of rows in that split.
# This replaces the value of `X_seq_lens` computed in the previous cell.
X_seq_lens = [
    np.array( [ len(X_sents[i]) ] )
    for i,name in enumerate(splits)
]

In [17]:
# Read Candidate labels from multiple LFs
#
# Builds:
#   tokens - the `tokens` column of the first non-empty file read
#   lfs    - {LF key -> list of labels}, truncated to len(Y_tokens)
# The LF key is the file path below `candidate_generation/`, without the
# `.tsv` suffix and with `/` replaced by `_`.
indir = '/mnt/nas2/results/Results/systematicReview/distant_pico/candidate_generation'
# Sorted so that the file that supplies `tokens` and the insertion order of
# `lfs` do not depend on the directory listing order.
pathlist = sorted(Path(indir).glob('**/*.tsv'))

tokens = []

lfs = dict()

for file in pathlist:

    # Skip empty files. Previously an empty file left `data` from the previous
    # iteration in place, so that file's labels were stored again under this
    # file's key (or a NameError was raised if the first file was empty).
    if file.stat().st_size == 0:
        continue

    k = str( file ).split('candidate_generation/')[-1].replace('.tsv', '').replace('/', '_')
    data = pd.read_csv(file, sep='\t', header=0)
    if len(tokens) == 0:
        tokens.extend( list(data.tokens) )

    # The labels are taken from the last column of each file.
    sab = data.columns[-1]
    # Keep only files with exactly 1354953 rows (the token count printed below).
    # NOTE(review): hard-coded corpus size -- must be updated if the corpus changes.
    if len(data) == 1354953:
        lfs[str(k)] = list( data[sab] )[:len(Y_tokens)]


print( 'Total number of tokens in validation set: ', len(tokens) )
print( 'Total number of LFs in the dictionary', len(lfs) )

Total number of tokens in validation set:  1354953
Total number of LFs in the dictionary 617


In [18]:
def lf_levels(umls_d:dict, pattern:str, picos:str):
    """Select the entries of `umls_d` whose key starts with ``pattern + picos``.

    Each selected entry is re-keyed by the last ``_``-separated token of its
    original key; values are passed through untouched. When two selected keys
    end in the same token, the later one wins.
    """
    return {
        str(name).split('_')[-1]: labels
        for name, labels in umls_d.items()
        if name.startswith(pattern + picos)
    }

# Split the loaded LFs into levels by key prefix; each dict maps the last
# `_`-separated token of the LF key to its label list.
# NOTE(review): the UMLS prefix is matched with a lower-case suffix ('p'/'i'/'o')
# while every other level uses upper case ('P'/'I'/'O') -- presumably this mirrors
# the file naming; verify against the keys of `lfs`.

# Level 1: UMLS
umls_p = lf_levels(lfs, 'UMLS_fuzzy_', 'p')
umls_i = lf_levels(lfs, 'UMLS_fuzzy_', 'i')
umls_o = lf_levels(lfs, 'UMLS_fuzzy_', 'o')

# Level 2: non UMLS
nonumls_p = lf_levels(lfs, 'nonUMLS_fuzzy_', 'P')
nonumls_i = lf_levels(lfs, 'nonUMLS_fuzzy_', 'I')
nonumls_o = lf_levels(lfs, 'nonUMLS_fuzzy_', 'O')

# Level 3: DS
ds_p = lf_levels(lfs, 'DS_fuzzy_', 'P')
ds_i = lf_levels(lfs, 'DS_fuzzy_', 'I')
ds_o = lf_levels(lfs, 'DS_fuzzy_', 'O')

# Level 4: dictionary, rules, heuristics
heur_p = lf_levels(lfs, 'heuristics_direct_', 'P')
heur_i = lf_levels(lfs, 'heuristics_direct_', 'I')
heur_o = lf_levels(lfs, 'heuristics_direct_', 'O')

dict_p = lf_levels(lfs, 'dictionary_direct_', 'P')
dict_i = lf_levels(lfs, 'dictionary_direct_', 'I')
dict_o = lf_levels(lfs, 'dictionary_direct_', 'O')

In [19]:
def compare(s, t):
    """Return True when `s` and `t` contain the same items, ignoring order."""
    left = list(s)
    left.sort()
    right = list(t)
    right.sort()
    return left == right

def getLFs(partition:list, umls_d:dict, seed_len:int):
    """Merge the label lists of every key group in `partition` into one list each.

    Each group starts from ``[0] * seed_len`` and folds in ``umls_d[key]`` for
    every key of the group, position by position:

    * current/new values are {-1, 1} or {0, 1} -> 1
    * current/new values are {-1, 0}           -> -1
    * anything else (e.g. equal values)        -> the current value is kept

    Merging walks both lists with ``zip``, so the result is as long as the
    shorter of the running list and the incoming one.

    Returns a list with one merged label list per group, in partition order.
    """
    all_lfs_combined = []

    for group in partition:
        merged = [0] * seed_len

        for sab in group:
            incoming = umls_d[sab]
            updated = []
            for current, candidate in zip(merged, incoming):
                # Order-insensitive comparison of the two values.
                pair = sorted([current, candidate])
                if pair == [-1, 1] or pair == [0, 1]:
                    updated.append(max(current, candidate))
                elif pair == [-1, 0]:
                    updated.append(min(current, candidate))
                else:
                    updated.append(current)
            merged = updated

        all_lfs_combined.append(merged)

    return all_lfs_combined

In [20]:
def grid_search(model_class,
                model_class_init,
                param_grid,
                train=None,
                dev=None,
                other_train=None,
                n_model_search=5,
                val_metric='f1',
                seed=1234,
                checkpoint_gt_mv=True,
                tag_fmt_ckpnt='BIO'):
    """Simple grid search helper function

    Fits one label model per sampled configuration, scores it on the dev
    split, keeps the configuration with the highest `val_metric`, and finally
    refits a fresh model with that configuration.

    Parameters
    ----------
    model_class
        Class instantiated as ``model_class(**model_class_init)``; instances
        must provide ``fit``, ``predict`` and ``score`` as called below.
    model_class_init : dict
        Keyword arguments for the model constructor.
    param_grid : dict
        Parameter name -> candidate values. Passed to ``sample_param_grid``
        (defined elsewhere in this notebook); its keys are zipped with each
        sampled configuration.
    train : tuple
        ``(L_train, Y_train)``. `Y_train` is only used to report training scores.
    dev : tuple
        ``(L_dev, Y_dev)``. `Y_dev` is passed to ``fit`` as its second
        positional argument and used as gold labels for model selection.
    other_train
        Unused.
    n_model_search : int
        Number of sampled configurations to try.
    val_metric : str
        Key of the score dict used for model selection.
    seed : int
        Passed to ``sample_param_grid`` and added to every configuration that
        does not set ``'seed'`` itself.
    checkpoint_gt_mv : bool
        Intended to skip configurations scoring below a majority-vote
        baseline; see the NOTE(review) in the body -- currently has no effect.
    tag_fmt_ckpnt : str
        ``'IO'`` collapses every non-zero gold tag to 1 before scoring; any
        other value uses the gold tags unchanged.

    Returns
    -------
    tuple
        ``(model, best_config, best_score)`` where `model` is refit with
        `best_config` and `best_score` is the dev score dict of that config.

    Raises
    ------
    RuntimeError
        If no configuration could be selected (every candidate was skipped).
    """
    score_metrics = ['accuracy', 'precision', 'recall', 'f1', 'f1_macro']

    def _gold_tags(labels):
        # 'IO' evaluation: every non-zero tag counts as the positive class.
        if tag_fmt_ckpnt == 'IO':
            return np.array([0 if y == 0 else 1 for y in labels])
        return labels

    L_train, Y_train = train
    L_dev, Y_dev = dev

    # sample configs
    params = sample_param_grid(param_grid, seed)[:n_model_search]

    defaults = {'seed': seed}
    best_score, best_config = 0.0, None
    print(f"Grid search over {len(params)} configs")

    for i, config in enumerate(params):
        print(f'[{i}] Label Model')
        config = dict(zip(param_grid.keys(), config))
        # update default params if not specified
        config.update({param: value for param, value in defaults.items() if param not in config})

        model = model_class(**model_class_init)
        model.fit(L_train, Y_dev, **config)

        y_pred = model.predict(L_dev)

        # set gold tags for evaluation
        y_gold = _gold_tags(Y_dev)

        if -1 in y_pred:
            print("Label model predicted -1 (TODO: this happens inconsistently)")
            continue

        # use internal label model scorer to score the prediction
        metrics = model.score(L=L_dev,
                              Y=y_gold,
                              metrics=score_metrics,
                              tie_break_policy=0)

        # compare learned model against MV on same labeled dev set
        # skip if LM less than MV
        # NOTE(review): `mv_metrics` is computed with the *same* model as
        # `metrics`, not with a majority-vote baseline, so this check cannot
        # skip anything. Left unchanged because switching to a real
        # majority-vote model changes which configs are selected and its
        # tie-break behaviour needs to be confirmed first.
        if checkpoint_gt_mv:
            mv_metrics = model.score(L=L_dev,
                                     Y=y_gold,
                                     metrics=score_metrics,
                                     tie_break_policy=0)

            if metrics[val_metric] < mv_metrics[val_metric]:
                continue

        if not best_score or metrics[val_metric] > best_score[val_metric]:
            print(config)
            best_score = metrics
            best_config = config

            # print training set score if we have labeled data
            if np.any(Y_train):
                train_metrics = model.score(L=L_train,
                                            Y=_gold_tags(Y_train),
                                            metrics=score_metrics,
                                            tie_break_policy=0)

                print('[TRAIN] {}'.format(' | '.join([f'{m}: {v * 100:2.2f}' for m, v in train_metrics.items()])))

            print('[DEV]   {}'.format(' | '.join([f'{m}: {v * 100:2.2f}' for m, v in best_score.items()])))
            print('-' * 88)

    # Without this guard the refit below fails with an opaque
    # "argument after ** must be a mapping, not NoneType".
    if best_config is None:
        raise RuntimeError(
            f"grid_search: none of the {len(params)} sampled configurations "
            "produced a usable model (all candidates were skipped)"
        )

    # retrain best model
    print('BEST')
    print(best_config)
    model = model_class(**model_class_init)
    model.fit(L_train, Y_dev, **best_config)
    return model, best_config, best_score

In [34]:
def train(partitioned_d_umls, umls_d, non_umls_d, ds_d, heur_d, dict_d, df_data_train, df_data_val, picos, paramgrid):
    """Grid-search a label model for every UMLS partition and keep the best.

    For each partition the UMLS label lists are merged per key group
    (``getLFs``), the label lists of the other LF levels are appended, the
    resulting label matrix is split 80/20 without shuffling, and
    ``grid_search`` is run over the full `paramgrid`.

    Parameters
    ----------
    partitioned_d_umls : list
        Partitions as produced by ``partitionLFs`` (lists of key groups).
    umls_d : dict
        Key -> label list for the UMLS LFs referenced by the partitions.
    non_umls_d, ds_d, heur_d, dict_d : dict
        Key -> label list for the remaining LF levels; all values are appended
        unchanged to every partition's LF set.
    df_data_train, df_data_val : pandas.DataFrame
        Labeled train / validation rows; ``[picos]`` selects the label column.
        Their combined length must equal the length of every label list so
        that the unshuffled 80/20 split of the label matrix lines up with them.
    picos : str
        Name of the label column to evaluate against.
    paramgrid : dict
        Parameter name -> candidate values, forwarded to ``grid_search``.

    Returns
    -------
    tuple
        ``(best_model, best_config, best_f1_macro)`` over all partitions;
        ``(None, None, 0.0)`` if no partition scored above 0.
    """
    best_f1_macro = 0.0
    best_overall_model = None
    best_overall_config = None

    model_class_init = {
        'cardinality': 2,
        'verbose': True
    }

    # Size of the full grid, taken from the `paramgrid` argument (previously
    # read from the global `param_grid`); the whole grid is searched.
    num_hyperparams = functools.reduce(lambda x, y: x * y, [len(v) for v in paramgrid.values()], 1)
    print("Hyperparamater Search Space:", num_hyperparams)
    n_model_search = num_hyperparams

    # Number of rows of the label matrix, derived from the arguments
    # (previously read from the global `Y_tokens`).
    n_tokens = len(df_data_train) + len(df_data_val)

    # The non-UMLS levels are identical for every partition: levels 2, 3, 4.
    other_level_lfs = (
        list(non_umls_d.values())
        + list(ds_d.values())
        + list(heur_d.values())
        + list(dict_d.values())
    )

    Y_train = df_data_train[picos]
    Y_val = df_data_val[picos]

    # Choosing the number of LF's from UMLS all
    for partition in partitioned_d_umls:

        combined_lf = getLFs(partition, umls_d, n_tokens)
        assert len(partition) == len(combined_lf)

        print( 'Total number of UMLS partitions: ', len(partition) )
        combined_lf.extend(other_level_lfs)

        # Rows = tokens, columns = labeling functions.
        L = np.transpose(np.array(combined_lf))
        L_train, L_val = train_test_split(L, test_size=0.20, shuffle=False)

        best_model, best_config, best_score = grid_search(LMsnorkel,
                                           model_class_init,
                                           paramgrid,
                                           train = (L_train, Y_train),
                                           dev = (L_val, Y_val),
                                           n_model_search=n_model_search,
                                           val_metric='f1_macro',
                                           seed=1234,
                                           tag_fmt_ckpnt='IO')

        if best_score['f1_macro'] > best_f1_macro:
            best_f1_macro = best_score['f1_macro']
            best_overall_model = best_model
            best_overall_config = best_config

        print('Best overall macro F1 score: ', best_f1_macro)
        print('Best overall configuration: ', best_overall_config)

    # Previously the best model/config were computed and then discarded.
    return best_overall_model, best_overall_config, best_f1_macro

In [35]:
# Hyper-parameter grid handed to train() and from there to grid_search(); each
# sampled combination is passed as keyword arguments to the label model's fit().
# Grid size: 2 * 2 * 7 * 4 * 3 * 1 = 336 combinations.
param_grid = {
    'lr': [0.001, 0.0001],
    'l2': [0.001, 0.0001],
    'n_epochs': [50, 100, 200, 600, 700, 1000, 2000],
    'prec_init': [0.6, 0.7, 0.8, 0.9],
    'optimizer': ["adamax", "adam", "sgd"],
    'lr_scheduler': ['constant'],
}

In [None]:
# Partition sweep + grid search against the 'p' label column.
train(partitioned_p_umls, umls_p, nonumls_p, ds_p, heur_p, dict_p, df_data_train, df_data_val, 'p', paramgrid = param_grid)

Hyperparamater Search Space: 336
Total number of UMLS partitions:  1
Grid search over 336 configs
[0] Label Model
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 200, 'prec_init': 0.8, 'optimizer': 'adam', 'lr_scheduler': 'constant', 'seed': 1234}
[TRAIN] accuracy: 87.89 | precision: 0.00 | recall: 0.00 | f1: 0.00 | f1_macro: 46.78
[DEV]   accuracy: 87.70 | precision: 0.00 | recall: 0.00 | f1: 0.00 | f1_macro: 46.72
----------------------------------------------------------------------------------------
[1] Label Model
[2] Label Model
[3] Label Model
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 700, 'prec_init': 0.9, 'optimizer': 'adamax', 'lr_scheduler': 'constant', 'seed': 1234}
[TRAIN] accuracy: 87.05 | precision: 32.19 | recall: 6.29 | f1: 10.52 | f1_macro: 51.77
[DEV]   accuracy: 86.93 | precision: 33.68 | recall: 6.46 | f1: 10.84 | f1_macro: 51.89
----------------------------------------------------------------------------------------
[4] Label Model
{'lr': 0.001, 'l2': 0.001, 'n_epochs': 600

Total number of UMLS partitions:  2
Grid search over 336 configs
[0] Label Model
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 200, 'prec_init': 0.8, 'optimizer': 'adam', 'lr_scheduler': 'constant', 'seed': 1234}
[TRAIN] accuracy: 87.95 | precision: 74.77 | recall: 0.76 | f1: 1.50 | f1_macro: 47.54
[DEV]   accuracy: 87.78 | precision: 74.09 | recall: 0.98 | f1: 1.93 | f1_macro: 47.71
----------------------------------------------------------------------------------------
[1] Label Model
[2] Label Model
[3] Label Model
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 700, 'prec_init': 0.9, 'optimizer': 'adamax', 'lr_scheduler': 'constant', 'seed': 1234}
[TRAIN] accuracy: 87.04 | precision: 32.14 | recall: 6.32 | f1: 10.57 | f1_macro: 51.79
[DEV]   accuracy: 86.91 | precision: 33.44 | recall: 6.50 | f1: 10.88 | f1_macro: 51.91
----------------------------------------------------------------------------------------
[4] Label Model
{'lr': 0.001, 'l2': 0.001, 'n_epochs': 600, 'prec_init': 0.6, 'optimizer'

[327] Label Model
[328] Label Model
[329] Label Model
[330] Label Model
[331] Label Model
[332] Label Model
[333] Label Model
[334] Label Model
[335] Label Model
BEST
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 2000, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant', 'seed': 1234}
Best overall macro F1 score:  0.5623404464548678
Best overall configuration:  {'lr': 0.001, 'l2': 0.0001, 'n_epochs': 2000, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant', 'seed': 1234}
Total number of UMLS partitions:  3
Grid search over 336 configs
[0] Label Model
{'lr': 0.001, 'l2': 0.0001, 'n_epochs': 200, 'prec_init': 0.8, 'optimizer': 'adam', 'lr_scheduler': 'constant', 'seed': 1234}
[TRAIN] accuracy: 87.89 | precision: 49.11 | recall: 0.90 | f1: 1.77 | f1_macro: 47.66
[DEV]   accuracy: 87.74 | precision: 58.27 | recall: 1.15 | f1: 2.26 | f1_macro: 47.86
----------------------------------------------------------------------------------------
[1] Label Model
[2] Label M

In [None]:
# Partition sweep + grid search against the 'i' label column.
train(partitioned_i_umls, umls_i, nonumls_i, ds_i, heur_i, dict_i, df_data_train, df_data_val, 'i', paramgrid = param_grid)

In [None]:
# Partition sweep + grid search against the 'o' label column.
train(partitioned_o_umls, umls_o, nonumls_o, ds_o, heur_o, dict_o, df_data_train, df_data_val, 'o', paramgrid = param_grid)