In [1]:
import numpy as np
from math import ceil
from functools import reduce

In [2]:
class TreeNode:
    """Non-leaf node of the binary tree: one vector and two child slots."""

    def __init__(self, length):
        # Vector attached to this node, all zeros on creation.
        self.val = np.zeros(shape=length)
        # Both children are unset until the tree builder fills them in.
        self.left = self.right = None
        
class LeafNode:
    """Terminal node holding a value `x` together with its code."""

    def __init__(self, x, code):
        self.val, self.code = x, code

In [4]:
def my_bin(x, digits):
    """
    Return the binary representation of an integer as a string,
    left-padded with zeros to at least `digits` characters.

    Parameters
    ----------
    x: int
        integer to be converted
    digits: positive integer
        minimum width of the returned string; longer representations
        are returned unchanged (never truncated)
    """
    raw = bin(x).replace('0b', '')
    # rjust only pads when the string is shorter than the requested width
    return raw.rjust(digits, "0")

def sigmoid(x):
    """
    Sigmoid (logistic) function: 1 / (1 + exp(-x)).
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def unique_words(corpus):
    """
    Find the unique words in the corpus.

    Parameters
    ----------
    corpus: list[list]
        each element of the list contains a list representing one sentence,
        where each element represents a word in a sentence

    Returns
    -------
    list
        every distinct word exactly once, in order of first appearance;
        an empty list when the corpus is empty
    """
    # dict.fromkeys de-duplicates while keeping insertion order, so the
    # result is reproducible between runs (iteration order of a set of
    # strings is not). A single pass also avoids the quadratic cost of
    # concatenating the sentences one by one.
    return list(dict.fromkeys(word for sentence in corpus for word in sentence))

def vec_init(words, length):
    """
    Build a dict mapping each word to a fresh random vector of the given
    length, drawn with np.random.rand.
    """
    return {token: np.random.rand(length) for token in words}

def encode(words):
    """
    Assign every word, plus the "<START>" and "<END>" markers appended after
    them, a fixed-width binary code: the word's position in that extended
    list, written in binary and zero-padded to a common width.
    """
    vocabulary = words + ["<START>", "<END>"]
    # smallest width able to number every entry of the extended list
    width = ceil(np.log2(len(words) + 2))
    return {token: my_bin(position, width) for position, token in enumerate(vocabulary)}

def tree_init(words, length):
    """
    Build the binary tree: one LeafNode per word (plus "<START>" and "<END>"),
    reached by following the word's binary code from the root ('0' = left,
    anything else = right). Every non-leaf node on the way is a TreeNode
    carrying a vector of the given length.
    """
    vocabulary = words + ["<START>", "<END>"]
    width = ceil(np.log2(len(words) + 2))
    root = TreeNode(length)
    for position, token in enumerate(vocabulary):
        bits = my_bin(position, width)
        parent = root
        # All bits but the last select (and lazily create) internal nodes.
        for bit in bits[:-1]:
            side = "left" if bit == '0' else "right"
            child = getattr(parent, side)
            if not child:
                child = TreeNode(length)
                setattr(parent, side, child)
            parent = child
        # The last bit decides on which side of the final parent the leaf hangs.
        last_side = "left" if bits[-1] == '0' else "right"
        setattr(parent, last_side, LeafNode(token, bits))
    return root

In [5]:
def hierarchical_softmax(corpus, length=10, window_size=2, learning_rate=0.01, epoches=10000):
    """
    Hierarchical softmax skip-gram model

    Parameters
    ----------
    corpus: list[list]
        each element of the list contains a list representing one sentence, where each element represents a word in a sentence
    length: positive int
        length of the vectors in result
    window_size: positive int
        size of context
    learning_rate: positive float
        learning rate of the gradient descent algorithm
    epoches: non-negative int
        number of full passes over the corpus

    Returns
    -------
    dict
        maps every unique word of the corpus to its trained vector
        (numpy array of size `length`); empty when the corpus is empty
    """
    if not corpus:
        # Nothing to train on.
        return dict()
    words = unique_words(corpus)
    v = vec_init(words, length)
    codes = encode(words)
    root = tree_init(words, length)
    for _ in range(epoches):
        for sentence in corpus:
            for i, w in enumerate(sentence):
                # Collect the context window around position i; positions that
                # fall outside the sentence are represented by boundary markers.
                context = []
                for j in range(i - window_size, i + window_size + 1):
                    if j == i:
                        continue
                    if j < 0:
                        context.append('<START>')
                    elif j >= len(sentence):
                        context.append('<END>')
                    else:
                        context.append(sentence[j])
                for u in context:
                    node = root
                    e = np.zeros(length)
                    # Walk the whole root-to-leaf path of u: one binary decision
                    # per code bit, so every bit must be visited. Stopping early
                    # would leave the last branchings (the ones that separate
                    # sibling leaves) untrained.
                    for bit in codes[u]:
                        q = sigmoid(v[w] @ node.val)
                        # target is 1 for a left turn ('0') and 0 for a right turn
                        g = learning_rate * (1 - int(bit) - q)
                        # accumulate the word-vector update with the node vector
                        # as it was *before* this step's node update
                        e += g * node.val
                        node.val += g * v[w]
                        node = node.left if bit == "0" else node.right
                    v[w] += e
    return v