In [1]:
import torch
from treelib import Tree, Node

In [3]:

# Integer ids of the six tree nodes; they are used both as treelib
# identifiers and as row indices into the (N, features) tensors.
root_id, left_id, middle_id, right_id = 0, 1, 2, 3
botom_left_id, botom_right_id = 4, 5

# Length of each per-node vector (second tensor dimension).
features = 5

# Number of rows in every per-node tensor (first tensor dimension).
N = 6


In [4]:
# Random (N, features) tensor that the training loop regresses the layer
# output onto. No seed is set, so the values differ on every fresh run.
target = torch.randn((N, features), dtype=torch.float32)
target

tensor([[-1.4633,  0.6768, -1.4848, -2.2866,  1.0574],
        [ 0.0288, -0.2464,  0.6284,  0.3609,  0.1281],
        [-1.8932,  1.3271,  0.4608, -0.4228, -0.2085],
        [ 0.2593,  0.4280, -0.0259,  0.0154, -2.1332],
        [ 0.8526,  1.1054, -2.1104,  0.6521,  0.5760],
        [-0.5097, -1.1057, -1.5116, -0.6783,  0.6962]])

In [5]:
# Build the example tree. Every node stores its own integer id as both
# `identifier` and `data`, so either can be used to index the tensors.
#
#   root(0) -> left(1), middle(2), right(3)
#   left(1) -> botom_left(4), botom_right(5)
tree = Tree()
tree.create_node(data=root_id, identifier=root_id)
tree.create_node(data=left_id, identifier=left_id, parent=root_id)
tree.create_node(data=middle_id, identifier=middle_id, parent=root_id)
tree.create_node(data=right_id, identifier=right_id, parent=root_id)
tree.create_node(data=botom_left_id, identifier=botom_left_id, parent=left_id)
# Last expression of the cell: the returned Node is echoed as the cell output.
tree.create_node(data=botom_right_id, identifier=botom_right_id, parent=left_id)

Node(tag=5, identifier=5, data=5)

In [6]:
# Per-node code weights, tracked by autograd.
# NOTE(review): no later visible cell references `W_code`; the layer is built
# with a separate, non-trainable `w_code` instead -- confirm which is intended.
W_code = torch.randn((N, features), dtype=torch.float32, requires_grad=True)

# Per-node code bias, initialised to ones and tracked by autograd. It is read
# as a global by CombinationLayer and updated by the training loop.
B_code = torch.ones((N, features), dtype=torch.float32, requires_grad=True)

In [7]:
def getChildrenIDs(id):
    """
    Return the `data` value of every direct child of node `id`.

    The lookup runs against the notebook-level `tree`; a node without
    children yields an empty list.
    """
    child_ids = []
    for child in tree.children(id):
        child_ids.append(child.data)
    return child_ids

In [9]:
#
# Combine the nodes
# p = Wcomb1 * vec(P) + Wcomb2 * tanh( SUM( Wcode * vec(Ci) + Bcode)  )
#
class CombinationLayer:
    """
    Combine every node of a tree with its direct children.

    For a node p with children c_1..c_k the output row is

        p' = Wcomb1[p] * vec(p) + Wcomb2[p] * tanh( SUM_i( Wcode[c_i] * vec(c_i) + Bcode[c_i] ) )

    where every `*` is an element-wise product. A node without children is
    passed through unchanged (p' = vec(p)).

    All tensors are indexed by each node's `data` value, which is also used as
    the node identifier when asking the tree for children.
    """

    def __init__(self, tree, weight_code, weight_left_combination, weight_right_combination,
                 node_data=None, bias_code=None):
        """
        Parameters
        ----------
        tree : tree object exposing `size()`, `all_nodes()` and `children(id)`
            (e.g. a treelib `Tree`); nodes must expose a `data` attribute.
        weight_code : tensor indexed by node id -- Wcode in the formula.
        weight_left_combination : tensor indexed by node id -- Wcomb1.
        weight_right_combination : tensor indexed by node id -- Wcomb2.
        node_data : optional tensor indexed by node id holding vec(node).
            When omitted, the notebook-level global `data` is read on every
            forward pass (the original behaviour).
        bias_code : optional tensor indexed by node id -- Bcode.
            When omitted, the notebook-level global `B_code` is read on every
            forward pass (the original behaviour).
        """
        self.tree = tree
        self.tree_size = tree.size()
        self.weight_code = weight_code
        self.weight_left_combination = weight_left_combination
        self.weight_right_combination = weight_right_combination
        self.node_data = node_data
        self.bias_code = bias_code

    def forward(self):
        """
        Run the forward pass and return one combined row per node,
        shaped like `weight_left_combination`.
        """
        return self._combine_tree(self.tree)

    def _node_data(self):
        """Return the node vectors: the injected tensor, else the global `data`."""
        # The global is resolved lazily so that, exactly as before, a missing
        # `data` only raises NameError when a forward pass actually needs it.
        return data if self.node_data is None else self.node_data

    def _bias_code(self):
        """Return the code bias: the injected tensor, else the global `B_code`."""
        return B_code if self.bias_code is None else self.bias_code

    def _combine(self, root_id, children_ids):
        """
        Combine one node with the given children and return its output row.
        """
        node_data = self._node_data()

        # If there's no children, return vec(p)
        if not children_ids:
            return node_data[root_id]

        bias_code = self._bias_code()

        # Sum the encoded children directly instead of scattering them into a
        # zero-filled (N, features) buffer and reducing over it.
        child_terms = [
            self.weight_code[child_id] * node_data[child_id] + bias_code[child_id]
            for child_id in children_ids
        ]
        tanh_sum = torch.tanh(torch.stack(child_terms).sum(dim=0))

        left_comb = self.weight_left_combination[root_id] * node_data[root_id]
        right_comb = self.weight_right_combination[root_id] * tanh_sum
        return left_comb + right_comb

    def _combine_tree(self, tree):
        """
        Combine every node of `tree` and return the stacked result.
        """
        # Output rows line up with the combination weights, so take shape,
        # dtype and device from them rather than from notebook globals.
        tot = torch.zeros_like(self.weight_left_combination)
        for n in tree.all_nodes():
            # Query the tree that was passed in, not a module-level one.
            children_ids = [c.data for c in tree.children(n.data)]
            tot[n.data] = self._combine(n.data, children_ids)
        return tot

In [None]:
# Code weights handed to the layer. Created without requires_grad, so autograd
# does not compute a gradient for them.
w_code = torch.randn((N, features), dtype=torch.float32)

# Combination weights for the node itself (W_comb1) and for the tanh of its
# summed children (W_comb2); both are tracked by autograd.
W_comb1 = torch.randn((N, features), dtype=torch.float32, requires_grad=True)
W_comb2 = torch.randn((N, features), dtype=torch.float32, requires_grad=True)

In [None]:
learning_rate = 1e-4

layer = CombinationLayer(tree, w_code, W_comb1, W_comb2)

# Plain gradient descent on the summed squared error between the layer
# output and `target`, for a fixed 1000 steps.
# NOTE(review): only W_comb1, W_comb2 and B_code are updated; `w_code` has no
# requires_grad and stays fixed -- confirm that is intended.
for t in range(1000):

    y_pred = layer.forward()
    loss = ((y_pred - target) ** 2).sum()

    # To log progress, uncomment:
    # if (t % 10 == 0): print(t, loss.item())

    loss.backward()

    # Update the parameters in place without recording the update in the
    # autograd graph, then clear each gradient so the next backward() starts
    # from zero instead of accumulating.
    with torch.no_grad():

        W_comb1.sub_(learning_rate * W_comb1.grad)
        W_comb1.grad.zero_()

        W_comb2.sub_(learning_rate * W_comb2.grad)
        W_comb2.grad.zero_()

        B_code.sub_(learning_rate * B_code.grad)
        B_code.grad.zero_()