In [154]:
import re
from itertools import product
from typing import Callable, List, Optional, Tuple

import numpy as np
import spacy
from spacy.tokens import Doc, Span
from thinc.api import Linear, Logistic, Model, chain
from thinc.types import Floats2d, Ints1d, Ragged, cast

In [77]:
# Width of a single spaCy word vector, kept as a notebook-wide constant.
# NOTE(review): must agree with the vectors of whichever pipeline is loaded — confirm.
VECTOR_LENGTH: int = 300

In [3]:
@spacy.registry.architectures.register('microsp_host_rel_model.v1')
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d]
) -> Model[List[Doc], Floats2d]:
    """Assemble the relation model: instance-tensor layer feeding the classifier.

    Args:
        create_instance_tensor: layer mapping a batch of docs to a 2-D float array;
            it must carry a ``'get_instances'`` entry in its ``attrs``.
        classification_layer: layer scoring each row of that array.

    Returns:
        The chained model. The instance generator from ``create_instance_tensor``
        is copied onto its ``attrs`` so it can be reached from the combined model.
    """
    relation_model = chain(create_instance_tensor, classification_layer)
    # Expose the instance generator on the outer model as well.
    instance_getter = create_instance_tensor.attrs['get_instances']
    relation_model.attrs['get_instances'] = instance_getter
    return relation_model

In [4]:
@spacy.registry.architectures.register('microsp_host_rel_classification_layer.v1')
def create_classification_layer(
    n0: Optional[int] = None, nI: Optional[int] = None, nO: Optional[int] = None
) -> Model[Floats2d, Floats2d]:
    """Build the output layer: a linear projection followed by a logistic (sigmoid).

    Args:
        n0: Output width, legacy spelling (digit zero). Kept so existing callers and
            configs that pass ``n0`` keep working; prefer ``nO``.
        nI: Input width; ``None`` leaves it for thinc to infer at initialization.
        nO: Output width, using thinc's own spelling (letter O).

    Returns:
        ``chain(Linear, Logistic)`` mapping a 2-D float array to 2-D scores.

    Raises:
        ValueError: if ``n0`` and ``nO`` are both given and disagree.
    """
    if n0 is not None and nO is not None and n0 != nO:
        raise ValueError(
            f"create_classification_layer got conflicting output widths n0={n0} and nO={nO}"
        )
    n_out = nO if nO is not None else n0
    # thinc's Linear names its output width ``nO`` (letter O); passing ``n0=``
    # (digit zero) is not an accepted keyword and raises TypeError.
    return chain(Linear(nO=n_out, nI=nI), Logistic())

In [5]:
def instance_forward(
    model: Model[List[Doc], Floats2d],
    docs: List[Doc],
    is_train: bool
) -> Tuple[Floats2d, Callable]:
    """Forward pass turning a batch of docs into one feature row per candidate instance.

    For each doc, ``model.attrs['get_instances']`` yields instances. Each instance is
    iterated as a sequence of entities exposing ``.start``/``.end`` token offsets. The
    token vectors of every entity are pooled by the ``pooling`` ref, and each
    consecutive pair of pooled entity rows is concatenated into one output row.

    NOTE(review): this expects instances made of Span-like objects (pairs of entities).
    The ``microsp_host_rel_instance_generator.v1`` generator in this notebook returns
    numpy vectors instead, which have no ``.start``/``.end`` — confirm which instance
    format is intended before wiring the two together.

    Args:
        model: layer holding ``tok2vec`` and ``pooling`` refs and a ``get_instances``
            entry in ``attrs``.
        docs: batch of documents.
        is_train: forwarded to ``tok2vec`` and ``pooling``.

    Returns:
        ``(relations, backprop)``: a 2-D array with two pooled entity vectors side by
        side per row, and the callback mapping its gradient back through pooling and
        tok2vec.
    """
    tok2vec = model.get_ref('tok2vec')
    pooling = model.get_ref('pooling')  # default pool = mean pool
    get_instances = model.attrs['get_instances']

    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)  # tok2vec is trained

    # Collect the token rows of every entity (one slice per doc) and the token count
    # of each entity, so the pooling layer can reduce each entity to a single row.
    ents = []
    lengths = []
    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        # Once per doc: appending inside the instance loop duplicated rows.
        ents.append(tokvec[token_indices])

    # Everything below runs once for the whole batch, not once per doc.
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype='int32'))
    entities = Ragged(model.ops.flatten(ents), lengths)
    pooled, bp_pooled = pooling(entities, is_train)

    # Concatenate each consecutive pair of pooled rows: output width is twice the
    # entity width.
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        """Propagate the gradient of ``relations`` back through pooling and tok2vec."""
        # Undo the pair concatenation: one gradient row per entity again.
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        # ent_index advances by each entity's token count, i.e. d_ents is assumed to
        # be laid out like the flattened entity token rows built above.
        ent_index = 0
        for instances, tokvec in zip(all_instances, tokvecs):
            d_tokvec = model.ops.alloc2f(*tokvec.shape)
            count_occ = model.ops.alloc2f(*tokvec.shape)
            for instance in instances:
                for ent in instance:
                    # Uses the first token row's gradient for every token of the entity
                    # (valid when pooling spreads an equal gradient to each token, as a
                    # mean does — revisit if the pooling layer changes).
                    d_tokvec[ent.start : ent.end] += d_ents[ent_index]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += ent.end - ent.start
            # Average over the entities touching each token; epsilon avoids 0/0 for
            # tokens outside every entity.
            d_tokvec /= count_occ + 1e-11
            d_tokvecs.append(d_tokvec)

        return bp_tokvecs(d_tokvecs)

    return relations, backprop

In [161]:
@spacy.registry.misc("microsp_host_rel_instance_generator.v1")
def create_instances() -> Callable[[Doc], List[np.ndarray]]:
    """Return a function building one vector per (microsporidia, host) pair in a doc.

    Entities labelled ``MICROSPORIDIA`` and ``HOST`` are grouped into alias groups
    (see ``are_aliased_names``). Each group is represented by the mean of its
    entities' ``Span.vector``. Every microsporidia group is paired with every host
    group, and an instance is the mean of the two group vectors.
    """
    def get_abbreviated_name(species: str) -> str:
        """Abbreviate all words but the last to their initial ('Amblyospora hasseri' -> 'A. hasseri').

        Names without a space are returned unchanged. Trailing parenthesized text
        (ex: names of discovering scientists) is dropped first, as it is typically
        irrelevant. If a single word remains, the result has a leading space that
        callers strip.
        """
        if ' ' not in species:
            return species

        words = re.sub(r' ?\(.+\)$', '', species).split()
        if not words:
            # The whole name was trailing parenthesized text; indexing words[-1]
            # would raise IndexError, so leave the name as-is.
            return species
        initials = ' '.join(word[0] + '.' for word in words[:-1])
        return f"{initials} {words[-1]}"

    def are_aliased_names(ent_1: str, ent_2: str) -> bool:
        """True if one name contains the other, as written or after abbreviation.

        Covers identical names, substrings, and full name vs. abbreviated form.
        """
        if ent_1 in ent_2 or ent_2 in ent_1:
            return True
        abbreviated_name_1 = get_abbreviated_name(ent_1).strip()
        abbreviated_name_2 = get_abbreviated_name(ent_2).strip()
        return abbreviated_name_1 in abbreviated_name_2 or abbreviated_name_2 in abbreviated_name_1

    def pool_aliases(ents: List[Span]) -> List[np.ndarray]:
        """Greedily group aliased entities and return one mean vector per group.

        In order, each entity not yet grouped starts a group and absorbs every other
        ungrouped entity aliased with it (compared against that first entity only).
        ex: 'Amblyospora hasseri' = 'A. hasseri' = 'Amblyospora hasseri (1)'.
        """
        group_vectors = []
        assigned = set()  # indices into ents already placed in a group
        for i, seed in enumerate(ents):
            if i in assigned:
                continue
            assigned.add(i)
            members = [
                j for j in range(len(ents))
                if j not in assigned and are_aliased_names(seed.text, ents[j].text)
            ]
            assigned.update(members)
            if members:
                group_vectors.append(
                    np.mean(np.vstack([seed.vector] + [ents[j].vector for j in members]), axis=0)
                )
            else:
                # No aliases: the entity's own (token-pooled) vector represents it.
                group_vectors.append(seed.vector)
        return group_vectors

    def get_instances(doc: Doc) -> List[np.ndarray]:
        microsp_vectors = pool_aliases([ent for ent in doc.ents if ent.label_ == 'MICROSPORIDIA'])
        host_vectors = pool_aliases([ent for ent in doc.ents if ent.label_ == 'HOST'])

        # 1 instance = average of a microsporidia group vector and a host group vector
        return [(m_vec + h_vec) / 2 for m_vec, h_vec in product(microsp_vectors, host_vectors)]

    return get_instances