<h1><center>Contextual Forest demo</center></h1>

## What is this project?
This project contains the code behind the Contextual Forest model, which I designed during [my bachelor's thesis](http://t.ly/bBNd). It includes a demo notebook with
code snippets showing the main functionality, and it informally presents the ideas behind the model as well as
the key points discussed in my bachelor's thesis regarding the task of Word Sense Disambiguation (WSD).

## Introduction

Contextual Forests, like many other language models, operate under the assumption that the meaning of a word in a sentence can
be fully determined by its context (i.e. the other words in the sentence). If one thinks about it for a second, that seems like
a very reasonable assumption, since this is essentially what humans do when they communicate with each other.

Unfortunately, in many situations the context has proven rather difficult to interpret, even with the most advanced techniques.
This is a direct consequence of the knowledge-based component of the context. Let's illustrate this with an example.

Suppose we are given the following sentence:

<center><i>"The best Queen songs redefined rock"</i></center>

Now let's focus our attention on the words _Queen_ and _rock_. For us humans, if we are familiar with 70's rock, it's an incredibly easy and almost automatic process to recognise that we are talking about the British rock band _Queen_ and that _rock_ is a music genre. For a language model, however, this can be
a far more challenging situation, since both _Queen_ and _rock_ as separate words can refer to a considerable number of different things (for example, a female monarch and a solid aggregate of minerals).

So how do language models deal with this problem? One of the key points
is that these words appear <u>together</u> in the same sentence. Probabilistically speaking, I could be talking about
Queen Elizabeth II of the United Kingdom, but it's highly unlikely, since I'm also talking about some concept
represented by the word _rock_ which has not been statistically associated with the context of Queen Elizabeth II.
From this point it's only natural to wonder how a language model assigns these probabilities and where the statistics come from.
Depending on the answer to this question, we can roughly classify language models into two types:

* **Context-free models**, such as Word2Vec or GloVe, are based on creating a 1:1 mapping between words and vectors (usually referred to as word embeddings). Although the implementation specifics differ, the general idea behind these models is to train a model over a large corpus of text (a shallow neural network in Word2Vec's case, a factorization of word co-occurrence statistics in GloVe's) to get a representation that captures semantic properties. For example, if we consider the vectors $k, w, m$ corresponding to the words "king", "woman" and "man" respectively, then the vector $q = (k - m) + w$ is very close to the one assigned to the word "queen". This can be extremely useful in many situations, but it's rather complicated to solve the WSD problem with this kind of model, since the embedding of every word is unique and independent of the context.

<center><img src="./imgs/glove.png" alt="drawing" width="500"/></center>

* **Dynamic-embedding models** like ELMo, BERT and RoBERTa assign the embedding of a word **depending** on the other words in the context. BERT and RoBERTa are based on the Transformer architecture (ELMo relies on bidirectional LSTMs instead), and since their first appearance around 2018 this kind of model has been at the top of the NLP world. These models also need huge corpora for training and considerably more computational power than context-free models, but their results on the WSD problem are significantly better. Let's illustrate this with an example:

<center><img src="./imgs/green.png" alt="drawing" width="500"/></center>
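The king/queen arithmetic from the context-free bullet can be reproduced with a handful of vectors. The 3-dimensional values below are hypothetical and were chosen by hand so that the analogy works. Real Word2Vec or GloVe vectors are learned from a corpus and have hundreds of dimensions. This is only a sketch of the nearest-neighbour lookup, not of how those models are trained.

```python
import math

# Hypothetical toy embeddings (hand-picked, NOT learned from any corpus).
EMB = {
    "king":  [0.9, 0.8, 0.1],
    "man":   [0.1, 0.8, 0.1],
    "woman": [0.1, 0.1, 0.9],
    "queen": [0.9, 0.1, 0.9],
    "rock":  [0.2, 0.9, 0.4],
}


def cosine(a, b):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def analogy(a, b, c):
    """Return the vocabulary word closest to vec(a) - vec(b) + vec(c), excluding a, b and c."""
    target = [x - y + z for x, y, z in zip(EMB[a], EMB[b], EMB[c])]
    candidates = [w for w in EMB if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(EMB[w], target))


print(analogy("king", "man", "woman"))  # queen
```

The same lookup also shows the limitation for WSD. `EMB["rock"]` is one fixed vector, whether the sentence is about music or geology.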

Although dynamic-embedding models perform significantly better on the WSD problem, they share a key aspect of the training process with context-free models: they both need a big text corpus. Having such a corpus means that, with enough computational power and some fancy architectures, one can build a pretty decent model based entirely on statistics inferred from the corpus (for example by training masked language models, or by predicting the next word like the GPT family), with no real understanding of the language. So what are we looking at when we are face to face with state-of-the-art models: language understanding or statistical inference? The answer is somewhere in between. It's clear that language models have mastered the syntax of language, but they have a long way to go before they understand its subtleties (see [this article](https://aclanthology.org/P19-1459.pdf) by Niven and Kao for an example).


While I was thinking about how people solve the WSD problem on a daily basis in conversations, I came to the conclusion that the disambiguation process could not be memory-based. In the previous example, we don't know that _Queen_ is a British rock band because one night, while discussing music in a bar, we heard a friend of a friend use the words _Queen_ and _rock_ in the same sentence. We know that _Queen_ is a British rock band and _rock_ is a music genre because, probabilistically speaking, those meanings are the most consistent options given the rest of the words in the sentence. Specifically, after reading the first three words (_The best Queen_), _Queen_ can still mean many things in our head. The moment we read the word _songs_, we think "Okay, we are talking about music, so this _Queen_ must be the rock band". This kind of association is exactly the idea behind Contextual Forests.

## Contextual Forest

So I wanted to build a disambiguation system based on semantic connections, using context about the possible meanings of the words in a sentence. Using context to make connections and find common ground between word meanings sounded a lot like expanding nodes in a graph search algorithm. So I decided to model this expansion process as different Trees (one for each possible meaning) trying to connect with each other. To make the model work, I needed some kind of context about the specific meanings of a given word, organized in a graph-like structure, so I decided to use Wikipedia. Unfortunately, Wikipedia only provides articles about nouns (objects, people, events...). Due to this impasse and to time restrictions, the implemented version of Contextual Forests only works for disambiguating nouns. Nevertheless, the process is easily scalable if one finds another resource covering more words.

<center><img src="./imgs/model_idea.png" alt="drawing" width="500"/></center>
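Here is a minimal sketch of this "trees trying to connect" idea for two words. It rests on several assumptions:

* The link graph is a hand-made dictionary standing in for Wikipedia.
* Node names are plain strings.
* The expansion order uses the Jaccard overlap of link sets as a placeholder for the relevance heuristic developed below.

It follows the general shape of the project's `Forest` class: a priority queue of node pairs, expanded until the trees share a node. It is not that implementation.

```python
import heapq
from itertools import count

# Hypothetical toy link graph standing in for Wikipedia article links.
LINKS = {
    "Elizabeth II": {"United Kingdom", "Monarchy", "Commonwealth"},
    "Queen (band)": {"Freddie Mercury", "Rock music", "Guitar"},
    "Rock (geology)": {"Mineral", "Geology", "Sediment"},
    "Rock music": {"Guitar", "Rock and roll", "Queen (band)"},
}


def jaccard(a, b):
    """Placeholder heuristic: overlap between the link sets of two articles."""
    la, lb = LINKS.get(a, set()), LINKS.get(b, set())
    union = la | lb
    return len(la & lb) / len(union) if union else 0.0


def disambiguate_pair(candidates_1, candidates_2):
    """Grow one tree per word from its candidate meanings until the trees meet.

    Returns the pair of root meanings whose expansion produced the connection,
    or None if the queue runs out first.
    """
    # Each tree maps every node it has reached to the candidate meaning (root) it grew from.
    reached_1 = {c: c for c in candidates_1}
    reached_2 = {c: c for c in candidates_2}
    tie = count()  # tie-breaker so heapq never compares node names on equal scores
    heap = []

    def push_pairs(us, vs):
        for u in us:
            for v in vs:
                score = jaccard(u, v)
                if score > 0:  # pairs with no overlap at all are never queued
                    heapq.heappush(heap, (-score, next(tie), u, v))  # max-heap via negation

    def expand(reached, node):
        new = [n for n in sorted(LINKS.get(node, ())) if n not in reached]
        for n in new:
            reached[n] = reached[node]  # new nodes inherit the root of the expanded node
        return new

    push_pairs(candidates_1, candidates_2)
    while heap:
        _, _, u, v = heapq.heappop(heap)
        new_1 = expand(reached_1, u)
        new_2 = expand(reached_2, v)
        if reached_1.keys() & reached_2.keys():  # the two trees share a node
            return reached_1[u], reached_2[v]
        push_pairs(new_1, new_2)
    return None
```

With this toy graph, `disambiguate_pair(["Elizabeth II", "Queen (band)"], ["Rock (geology)", "Rock music"])` returns `("Queen (band)", "Rock music")`. Only that pair of meanings has any overlap to start from, and expanding it makes the trees meet.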

### Step by step
 
The first step is identifying the nouns in a sentence and finding all the possible meanings of each one. This could have been a problem, but fortunately Wikipedia has pages specifically designed for this task:

<center><img src="./imgs/disambiguation.png" alt="drawing" width="500"/></center>
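The internals of the project's `disambiguation()` helper are not shown in this document. As an illustration only, the function below is a hypothetical filter over the link titles of a disambiguation page. It keeps article links as candidate meanings and drops links to Wikipedia meta pages and back to disambiguation pages. The list of meta-page prefixes is an assumption. In the notebook, such titles could come from `wiki.page('shot (disambiguation)').links.keys()`.

```python
import re

# Assumed set of namespace prefixes for meta pages (help, categories, templates, ...).
META_PAGE = re.compile(r"^(Help|Wikipedia|Category|Template|Talk|Portal|Special|File):")


def candidate_meanings(link_titles):
    """Hypothetical filter: keep the article links of a disambiguation page as candidate meanings."""
    return [
        title
        for title in link_titles
        if not META_PAGE.match(title) and "(disambiguation)" not in title
    ]


titles = ["Queen (band)", "Queen regnant", "Elizabeth II", "Help:Disambiguation", "Queen (disambiguation)"]
print(candidate_meanings(titles))  # ['Queen (band)', 'Queen regnant', 'Elizabeth II']
```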


After finding the possible meanings, there is a practical problem: a Wikipedia page has a rather large number of links, and expanding them recursively can lead to a computationally infeasible search problem. So the next thing the algorithm needed was a "relevance function" that could evaluate which links to expand in order to quickly find a connection between Trees (in computer science this is commonly known as a heuristic function). This heuristic function needed to represent how close two Wikipedia articles are. At first, I thought of just finding common links between articles, but as it turns out, it's not uncommon for two completely unrelated Wikipedia pages to have common links:

```python
    >>> from ContextualForest import wiki
    >>> A = wiki.page('potato').links.keys()
    >>> B = wiki.page('Microsoft').links.keys()
    >>> len(A & B)
    12
```

What eventually ended up working was counting common links among the *relevant links* only, and this was a crucial step.
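The difference between counting all common links and counting only common *relevant* links can be shown on toy data. The sketch assumes each article's links already carry a relevance score; how the model computes such a score is the subject of the following paragraphs. All scores below are hypothetical, and so is the 0.5 threshold. They are not computed from the real articles.

```python
def relevant_links(link_scores, threshold):
    """Keep only the links whose relevance score reaches the threshold."""
    return {link for link, score in link_scores.items() if score >= threshold}


def shared_relevant_links(scores_a, scores_b, threshold=0.5):
    """Links that are relevant in *both* articles: the evidence that they are related."""
    return relevant_links(scores_a, threshold) & relevant_links(scores_b, threshold)


# Hypothetical link -> relevance scores.
potato = {"ISBN (identifier)": 0.1, "Doi (identifier)": 0.1, "Starch": 0.9, "Andes": 0.7}
microsoft = {"ISBN (identifier)": 0.1, "Doi (identifier)": 0.1, "Windows": 0.9, "Bill Gates": 0.8}
queen_band = {"ISBN (identifier)": 0.1, "Freddie Mercury": 0.9, "Hard rock": 0.6}
rock_music = {"ISBN (identifier)": 0.1, "Hard rock": 0.8, "Electric guitar": 0.7}

print(len(potato.keys() & microsoft.keys()))          # 2 raw common links
print(shared_relevant_links(potato, microsoft))       # set()
print(shared_relevant_links(queen_band, rock_music))  # {'Hard rock'}
```

Generic links that many pages share are filtered out by their low scores. A related pair keeps its topical link.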

So how do we define a relevant link? To figure that out, we first need to define which words are relevant in a Wikipedia page. I identified relevant words with the ones uniformly distributed over the text. Note that these relevant words don't need to be repeated many times, just a few times every once in a while.

<center><img src="./imgs/metrics.png" alt="drawing" width="500"/></center>
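The metric actually used is the one in the figure above (computed in the notebook by `set_relevance`). As a simpler stand-in, the sketch below rewards words that are spread evenly through a text. It splits the tokens into equal chunks and scores each word by the fraction of chunks it appears in. Both the chunking approach and the default of 4 chunks are assumptions, not the project's formula.

```python
from collections import defaultdict


def dispersion(tokens, n_chunks=4):
    """Fraction of equal-sized chunks of `tokens` in which each word appears (an assumed metric)."""
    n = len(tokens)
    chunks = [tokens[i * n // n_chunks:(i + 1) * n // n_chunks] for i in range(n_chunks)]
    seen_in = defaultdict(int)
    for chunk in chunks:
        for word in set(chunk):  # count each word at most once per chunk
            seen_in[word] += 1
    return {word: hits / n_chunks for word, hits in seen_in.items()}


tokens = ["rock", "guitar", "guitar", "guitar",
          "a", "b", "c", "d",
          "e", "f", "g", "h",
          "rock", "i", "j", "k"]
scores = dispersion(tokens)
print(scores["rock"], scores["guitar"])  # 0.5 0.25
```

"rock" occurs only twice but in two different chunks, so it outscores "guitar", which is repeated three times in a single chunk.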

After that, I tried to define link relevance as the average relevance of the words composing the link's title. This proved to be ineffective as a relevance metric. A non-weighted average is quite sensitive to outliers, which biases the score towards links that have the most relevant word as part of their title. To solve this issue I studied the distribution of the relevance score, and discovered that it follows Zipf's law.

<center><img src="./imgs/zipf.png" alt="drawing" width="500"/></center>
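A quick numerical check for Zipf behaviour is a least-squares fit on log-log axes: under Zipf's law, $\text{score}(r) = C / r^s$, so $\log \text{score}$ is linear in $\log r$. The scores below are synthetic and follow the law exactly. They are not the relevance scores from the figure.

```python
import math


def fit_zipf(scores):
    """Least-squares fit of log(score) = log(C) - s * log(rank) over the ranked positive scores."""
    ranked = sorted((x for x in scores if x > 0), reverse=True)
    xs = [math.log(rank) for rank in range(1, len(ranked) + 1)]
    ys = [math.log(x) for x in ranked]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    C = math.exp(mean_y - slope * mean_x)
    return C, -slope  # the Zipf exponent s is minus the slope


# Synthetic scores following Zipf's law with C = 2 and s = 1.2.
synthetic = [2.0 / rank ** 1.2 for rank in range(1, 51)]
C, s = fit_zipf(synthetic)
print(round(C, 6), round(s, 6))  # 2.0 1.2
```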

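Instead of computing link relevance as a simple average, I defined it as the inverse image of an average over rank positions. The words' scores are mapped to rank positions, those ranks are averaged, and the average is mapped back to a score. This captures the decreasing factor in the score values and corrects the bias.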


<center><img src="./imgs/link_relevance.png" alt="drawing" width="500"/></center>
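Here is a sketch of one reading of this definition. It assumes the scores follow a pure Zipf curve $\text{score} = C / r^s$ with known parameters $C$ and $s$ (both 1 by default here). Each title word's score is mapped back to its rank, the ranks are averaged, and the mean rank is mapped forward again. The exact formula in the figure may differ.

```python
def rank_of(score, C=1.0, s=1.0):
    """Inverse of the assumed Zipf curve score = C / rank**s."""
    return (C / score) ** (1 / s)


def link_relevance(word_scores, C=1.0, s=1.0):
    """Average the title words' rank positions, then map the mean rank back to a score."""
    mean_rank = sum(rank_of(w, C, s) for w in word_scores) / len(word_scores)
    return C / mean_rank ** s


def naive_relevance(word_scores):
    """Plain arithmetic mean of the word scores (the biased version)."""
    return sum(word_scores) / len(word_scores)


top_plus_filler = [1.0, 1 / 9]  # one top-ranked word (rank 1) and one weak word (rank 9)
two_mid_words = [1 / 3, 1 / 3]  # two moderately relevant words (rank 3 each)

print(naive_relevance(top_plus_filler), naive_relevance(two_mid_words))  # ~0.556 vs ~0.333
print(link_relevance(top_plus_filler), link_relevance(two_mid_words))    # ~0.2 vs ~0.333
```

The arithmetic mean ranks the title with one very relevant word first. Averaging over rank positions ranks the title whose words are consistently relevant first.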

At this point, using the Trees, we have all the necessary tools to disambiguate context with this unsupervised technique:

```python
    >>> from ContextualForest import contextual_forest, disambiguation
    >>> fr = contextual_forest("Queen redefined rock with their songs")
    >>> for word, node in fr.dic.items():
    ...     possible_meanings = len(disambiguation(word))
    ...     if not possible_meanings:
    ...         # no disambiguation page: the word has a single meaning
    ...         possible_meanings = 1
    ...     print(f"Word: {word}\t possible meanings: {possible_meanings}\n Chosen: {node.page.text[:100]} ...")
    ...
```
Which prints:

```
Word: rock	 possible meanings: 49
 Chosen: Rock music is a broad genre of popular music that originated as "rock and roll" in the United States ...
Word: songs	 possible meanings: 1
 Chosen: A song is a musical composition intended to be performed by the human voice. This is often done at d ...
Word: queen	 possible meanings: 40
 Chosen: Queen are a British rock band formed in London in 1970. Their classic line-up was Freddie Mercury (l ...
```


Note that the algorithm is far from perfect and can sometimes fail to disambiguate context. Still, this approach shows that context can be disambiguated by mining specific pieces of information in graph-based structures (Knowledge Graphs), mimicking a reasoning process. This is an alternative to training on millions of examples and learning from them the statistical information needed for disambiguation. In my opinion, the idea of putting effort into *how* models learn, instead of *how much* data we have to scale our models' capabilities, definitely deserves consideration if we want to build models that could ultimately reason like we humans do.


```
pip install torch==1.4.0
```

```python
import flair

from flair.embeddings import TransformerWordEmbeddings
from flair.data import Sentence

# init embedding
embedding = TransformerWordEmbeddings('bert-large-uncased')

sentence1 = Sentence('the grass is green.')
sentence2 = Sentence('she has green eyes and light brown hair.')
sentence3 = Sentence('when players hit the ball onto the green, they use a putter.')
# embed words in sentence
embedding.embed(sentence1)
embedding.embed(sentence2)
embedding.embed(sentence3)
```


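```python
from scipy.spatial.distance import cosine

# scipy's `cosine` returns the cosine *distance*, so 1 - distance is the cosine similarity
cosine_sim = lambda x, y: 1 - cosine(x, y)

# token embeddings are torch tensors: detach them and move them to the CPU before passing them to scipy
green1 = sentence1[3].embedding.detach().cpu().numpy()  # "green" as a colour
green2 = sentence2[2].embedding.detach().cpu().numpy()  # "green" as a colour
green3 = sentence3[7].embedding.detach().cpu().numpy()  # "green" in golf

print(f"Similarity 1-3: {cosine_sim(green1, green3)}")  # color - golf
print(f"Similarity 1-2: {cosine_sim(green1, green2)}")  # color - color
```

```
Similarity 1-3: 0.5156915783882141
Similarity 1-2: 0.7878293395042419
```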


```python
from ContextualForest import *
```

```python
A = wiki.page('Gazpacho').links.keys()
B = wiki.page('Microsoft').links.keys()
len(A & B)
```

```python
text = wiki.page('bohemian rhapsody').text
```

```python
d = set_relevance(*stem_text(text))
```

```python
fr = contextual_forest("Queen redefined rock with their songs")
```

```python
# Print results
for word, node in fr.dic.items():
    possible_meanings = len(disambiguation(word))
    if not possible_meanings:
        # no disambiguation page: the word has a single meaning
        possible_meanings = 1
    print(f"Word: {word}\t possible meanings: {possible_meanings}\n Chosen: {node.page.text[:100]} ...")
```

```
Word: rock	 possible meanings: 49
 Chosen: Rock music is a broad genre of popular music that originated as "rock and roll" in the United States ...
Word: songs	 possible meanings: 1
 Chosen: A song is a musical composition intended to be performed by the human voice. This is often done at d ...
Word: queen	 possible meanings: 40
 Chosen: Queen are a British rock band formed in London in 1970. Their classic line-up was Freddie Mercury (l ...
```


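```python
f1 = Forest(["Cat food", "Dog"])
# Debug helper, kept disabled: draining the queue like this before calling
# disambiguate() would leave nothing to expand.
# while not f1.Q.empty():
#     sim, u, v, t1, t2 = f1.Q.get()
#     print("sim:{},node:{},node:{}".format(sim, u.page.title, v.page.title))
f1.disambiguate()
f1.recover_words()
```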

```python
wiki.page('shot (disambiguation)').links.keys()
```

```python
disambiguation('shot')
```

```python
N = 10000000
d = {a: None for a in range(N)}
```

```python
%time N in d
```

```python
# keys run from 0 to N - 1, so d[N] would raise a KeyError
%time d[N - 1]
```

```python
import itertools
from queue import PriorityQueue as pq

# Tree and Node come from the ContextualForest package (`from ContextualForest import *` above).


class Forest():
    """ Implementation of the contextual forest main data structure for
        context-based semantic disambiguation.

        Attributes
        ----------
        words : list of str
            The keywords to disambiguate
        trees : dict
            Mapping between ContextualForest.Tree objects and the root nodes
            associated with those tree objects.
        dic : dict
            Mapping between the words provided to disambiguate and root nodes
            associated with tree objects formed in the disambiguation process.
        connections : dict
            Where keys are combinations of two possible trees and values are
            booleans indicating whether or not the pair of trees is connected.
        Q : queue.PriorityQueue
            The priority queue structure to manage connections and expansion order

        Methods
        -------
        disambiguate()
            Performs the disambiguation process (forward).
        recover_words()
            Recovers the node chosen for every word once the disambiguation process has completed.
    """
    def __init__(self, words):
        self.words = words
        self.trees = {}
        self.dic = {}
        for word in words:
            self.dic[word] = None
            self.trees[Tree(word)] = None
        # possible pairs of trees, stored as a list because it is iterated twice below
        self.tree_combs = list(itertools.combinations(self.trees, 2))
        self.connections = {pair: False for pair in self.tree_combs}
        self.Q = pq()
        for tree1, tree2 in self.tree_combs:
            for u in tree1.to_expand:
                for v in tree2.to_expand:
                    sim = -u.similarity(v)  # negative because pq pops the smallest value first
                    if sim == 0:
                        continue
                    self.Q.put((sim, u, v, tree1, tree2))  # (similarity, node_1, node_2, tree_1, tree_2)

    def disambiguate(self):
        """ Performs the forward disambiguation process, expanding the nodes until
            all trees are connected by a path.
        """
        # while some pair of trees is not connected and there are candidates left to check
        while not all(self.connections.values()) and not self.Q.empty():
            _, u, v, t1, t2 = self.Q.get()
            key = (t1, t2) if (t1, t2) in self.connections else (t2, t1)  # depends on itertools order
            if self.connections[key]:
                # this pair of trees is already connected: discard the entry
                continue
            # expand both nodes
            news_t1 = t1.expand_node(u)
            news_t2 = t2.expand_node(v)
            # check for intersection
            if any(n in t2.to_expand for n in t1.to_expand):
                self.connections[key] = True
                if self.trees[t1] is None:
                    self.trees[t1] = u.root
                if self.trees[t2] is None:
                    self.trees[t2] = v.root
            else:  # no connection: queue the pairs of newly expanded nodes
                for u in news_t1:
                    for v in news_t2:
                        sim = -u.similarity(v)
                        if sim == 0:
                            continue
                        self.Q.put((sim, u, v, t1, t2))

    def recover_words(self):
        """ Recovers the Nodes associated with the disambiguation of every word provided
            in the instantiation of the class and stores them in the `dic` attribute.
        """
        for tree, node in self.trees.items():
            if node is None:
                # this tree never connected to another one: nothing to recover
                continue
            found = False
            for link, page in node.page.links.items():
                if tree.word.lower() == link.lower():
                    self.dic[tree.word] = Node(page, 1)
                    found = True
            if not found:
                connection = node
                while connection.depth != 1:
                    connection = connection.parent
                self.dic[tree.word] = connection
```