In [2]:
from typing import List, Tuple, Callable

import spacy
from spacy.tokens import Doc, Span

from thinc.types import Floats2d, Ints1d, Ragged, cast
from thinc.api import Model, Linear, chain, Logistic

import numpy as np

import re

from itertools import product

In [3]:
# Length of a spaCy word vector, kept as a module-level constant.
VECTOR_LENGTH = 300

# Stop words: common English words like 'the', 'a' which aren't very informative.
# They are read from the English language defaults of a blank pipeline, so no
# trained model package has to be installed or loaded just to obtain the set.
STOP_WORDS = spacy.blank('en').Defaults.stop_words



AttributeError: module 'srsly' has no attribute 'msgpack_encoders'

In [3]:
@spacy.registry.architectures.register('microsp_host_rel_model.v1')
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d]
) -> Model[List[Doc], Floats2d]:
    """Build the relation model by chaining the two given layers.

    Args:
        create_instance_tensor: layer mapping a batch of docs to a 2d float
            array; it must carry a 'get_instances' entry in its ``attrs``.
        classification_layer: layer applied to the output of
            ``create_instance_tensor``.

    Returns:
        The chained model. The 'get_instances' attribute of the first layer
        is copied onto it so it can be looked up on the combined model.
    """
    relation_model = chain(create_instance_tensor, classification_layer)
    # Re-expose the instance generator on the outer model.
    instance_generator = create_instance_tensor.attrs['get_instances']
    relation_model.attrs['get_instances'] = instance_generator
    return relation_model

In [4]:
@spacy.registry.architectures.register('microsp_host_rel_classification_layer.v1')
def create_classification_layer(
    n0: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]:
    """Build the classification layer: a linear layer followed by a logistic.

    Args:
        n0: output width of the linear layer. The parameter keeps its
            original spelling (digit zero) so existing callers and configs
            continue to work; it is forwarded to thinc as ``nO``.
        nI: input width of the linear layer.

    Returns:
        ``chain(Linear, Logistic)``.
    """
    # thinc's Linear names its output dimension `nO` (capital letter O);
    # passing the keyword `n0` (digit zero) is rejected as an unexpected argument.
    return chain(Linear(nO=n0, nI=nI), Logistic())

In [5]:
def instance_forward(
    model: Model[List[Doc], Floats2d],
    docs: List[Doc],
    is_train: bool
) -> Tuple[Floats2d, Callable]:
    """Forward pass that turns a batch of docs into one row per instance.

    For every doc, ``model.attrs['get_instances']`` yields the instances; each
    instance is iterated as a sequence of entities exposing ``start``/``end``
    token offsets. The token vectors of every entity are pooled into a single
    row, and consecutive pairs of pooled rows are concatenated, so each
    instance is assumed to hold exactly two entities (TODO confirm against the
    instance generator in use).

    Args:
        model: model providing the 'tok2vec' and 'pooling' refs and the
            'get_instances' attribute.
        docs: batch of documents.
        is_train: forwarded to the tok2vec and pooling layers.

    Returns:
        A tuple ``(relations, backprop)`` where ``relations`` is the 2d array
        of concatenated entity representations and ``backprop`` maps the
        gradient of ``relations`` back through pooling and tok2vec.
    """
    tok2vec = model.get_ref('tok2vec')
    pooling = model.get_ref('pooling')  # default pool = mean pool
    get_instances = model.attrs['get_instances']

    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)  # tok2vec is trained

    # Per-doc arrays of entity token vectors, joined vertically further down.
    ents = []
    # Number of tokens in each entity, in the same order as the rows in `ents`.
    lengths = []

    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        # Append once per doc: `token_indices` accumulates over all instances
        # of the doc, so appending per instance would duplicate rows and no
        # longer match `lengths`.
        ents.append(tokvec[token_indices])

    # Everything below runs once for the whole batch, after the loop, so that
    # `lengths` is still a list while it is being filled and the return values
    # are defined even when `docs` is empty.
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype='int32'))
    entities = Ragged(model.ops.flatten(ents), lengths)
    pooled, bp_pooled = pooling(entities, is_train)

    # Reshape so that pairs of rows are concatenated: the width doubles.
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        # Undo the pair concatenation: back to one row per entity.
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        ent_index = 0
        for doc_nr, instances in enumerate(all_instances):
            shape = tokvecs[doc_nr].shape
            d_tokvec = model.ops.alloc2f(*shape)
            count_occ = model.ops.alloc2f(*shape)
            for instance in instances:
                for ent in instance:
                    # The row at the entity's first token is applied to the
                    # whole span; `ent_index` then skips past the entity's rows.
                    d_tokvec[ent.start : ent.end] += d_ents[ent_index]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += ent.end - ent.start
            # Average over the number of times a token was used; the epsilon
            # avoids division by zero for tokens outside every entity.
            d_tokvec /= count_occ + 0.00000000001
            d_tokvecs.append(d_tokvec)

        d_docs = bp_tokvecs(d_tokvecs)
        return d_docs

    return relations, backprop

In [161]:
@spacy.registry.misc("microsp_host_rel_instance_generator.v1")
def create_instances() -> Callable[[Doc], List[np.ndarray]]:
    """Create the instance generator registered for this component.

    Returns:
        The ``get_instances`` callable, which takes a single ``Doc``.
    """
    def get_instances(doc: Doc) -> List[np.ndarray]:
        """Compute the context vector of ``doc``.

        NOTE(review): this function is unfinished. It only computes the
        document context vector and has no return statement, so it currently
        returns None instead of the annotated list. TODO: build and return
        the instances.
        """
        def get_doc_context_vector(doc: Doc) -> np.ndarray:
            """Average the vectors of all non-entity, non-stop-word tokens."""
            # Collect every token index covered by an entity, not just the
            # first and last one, so interior tokens of longer entities are
            # excluded from the context as well.
            entity_positions = set()
            for ent in doc.ents:
                entity_positions.update(range(ent.start, ent.end))

            # STOP_WORDS holds lowercase forms, so compare the lowercased text;
            # otherwise capitalised stop words such as 'The' slip through.
            non_entity_tokens = [
                tok.vector
                for i, tok in enumerate(doc)
                if i not in entity_positions and tok.text.lower() not in STOP_WORDS
            ]

            # return average word vector of all context tokens
            # NOTE(review): if no context token remains the list is empty and
            # the average is NaN; decide on a fallback once the vector width
            # in use is settled.
            return np.average(non_entity_tokens, axis=0)

        # create vector representing all non-entity tokens in the document
        context_vector = get_doc_context_vector(doc)

    return get_instances