In [149]:
import tqdm as tqdm
import os
import torch
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import datetime
from torchvision import transforms  
from torch.optim import lr_scheduler

class TreeNode(object):
     def __init__(self, n_eta=None, split_attribute_M= 0, split_value=0,  y_pred=None,
                        y_avg=None, posterior_var=None, posterior_mean=None,
                        left=None, right=None):
          self.n_eta = n_eta
          self.split_attribute_M = split_attribute_M   # internal node
          self.split_value = split_value  # internal node
          self.y_pred = y_pred  # leaf
          self.y_avg = y_avg # leaf
          self.posterior_var = posterior_var # leaf
          self.posterior_mean = posterior_mean  #leaf
          self.left = left
          self.right = right

def inner_decisions(node, leafinfo, k):
    decisions_pairs= []
    all_paths = []
    leafinfos = leafinfo.copy()
    if(node.left==None and node.right==None):
        return decisions_pairs, [(k, node.y_avg, leafinfos)], k
    else:
        decisions_pairs.append((k, node.split_attribute_M, node.split_value))

        leafinfos.append((k, -1))
        dec, paths, K =  inner_decisions(node.left, leafinfos, k+1)
        decisions_pairs.extend(dec)
        all_paths.extend(paths)

        leafinfos.pop()
        leafinfos.append((k, +1))
        dec, paths , K =  inner_decisions(node.right, leafinfos, K+1)
        decisions_pairs.extend(dec)
        all_paths.extend(paths)

    return decisions_pairs, all_paths, K

In [15]:
trees = []
with open('mydata.txt', 'r') as fp:
    for line in fp.readlines():
        vals = iter(line.split())

        def helper():
            val = next(vals)
            if val == '#':
                return None
            elems = val.split("@")
            node = TreeNode(n_eta = float(elems[0]), split_attribute_M = float(elems[1]), 
                            split_value = float(elems[2]), y_pred=float(elems[3]), 
                             y_avg= float(elems[4]),posterior_var= float(elems[5]), posterior_mean=float(elems[6]))
            node.left = helper()
            node.right = helper()
            return node
        trees.append(helper())
        

In [80]:
left = TreeNode(n_eta=1, left = TreeNode(n_eta=2) , right =TreeNode(n_eta=3, left= TreeNode(n_eta=4), right= TreeNode(n_eta=5)))
right = TreeNode(n_eta=6, left = TreeNode(n_eta=7) , right =TreeNode(n_eta=8, left= TreeNode(n_eta=9), right= TreeNode(n_eta=10)))
tree = TreeNode(n_eta=0, left=left, right=right)
tree

<__main__.TreeNode at 0x7f3a45f9caf0>

In [135]:
# test
decisions, paths, K = inner_decisions(tree, [], k = 0)
print(decisions)
print(paths)

[(0, 0, 0), (1, 0, 0), (3, 0, 0), (6, 0, 0), (8, 0, 0)]
[(2, None, [(0, -1), (1, -1)]), (4, None, [(0, -1), (1, 1), (3, -1)]), (5, None, [(0, -1), (1, 1), (3, 1)]), (7, None, [(0, 1), (6, -1)]), (9, None, [(0, 1), (6, 1), (8, -1)]), (10, None, [(0, 1), (6, 1), (8, 1)])]


In [156]:
class NN(nn.Module):
    def __init__(self, inner_decisions, paths, input_dim = 42, gamma1 = 1.0, gamma2 = 1.0):
        super(NN, self).__init__()
        self.gamma1 = gamma1
        self.gamma2 = gamma2

        self.dic = {}
        self.lin1 = nn.Linear(input_dim, len(inner_decisions))
        self.set_connections_1_layer(inner_decisions)

        self.lin2 = nn.Linear(len(inner_decisions), 1 + len(inner_decisions))
        self.avgs = []
        self.set_connections_2_layer(paths)

        self.lin3 = nn.Linear(1+len(inner_decisions), 1)
        self.set_connections_3_layer()


    def forward(self, x):
        mid = self.lin1(x)
        mid = self.lin2(torch.tanh(self,gamma1*mid))
        return self.lin2(torch.tanh(self,gamma2*mid))

    def set_connections_1_layer(self, inner_decisions):
        self.lin1.weight.data = torch.zeros(self.lin1.weight.shape)
        self.lin1.bias.data = torch.zeros(self.lin1.bias.shape)
        for i, elem in enumerate(inner_decisions):
            # node k to layer index
            self.dic[elem[0]] = i 

            self.lin1.weight.data[i, int(elem[1])] = 1
            self.lin1.bias[i] = elem[2]

    def set_connections_2_layer(self, paths):
        self.lin2.weight.data = torch.zeros(self.lin2.weight.shape)
        self.lin2.bias.data = torch.zeros(self.lin2.bias.shape)
        for i, elem in enumerate(paths):
            _, avg, tuples = elem
            self.avgs.append(avg)
            self.lin2.bias[i] = (-len(tuples) +1/2)
            for a in tuples:
                self.lin2.weight.data[i, self.dic[a[0]]] = a[1] 

    def set_connections_3_layer(self):
        self.lin3.weight.data = 0.5*torch.as_tensor(self.avgs)
        self.lin3.bias.data = self.lin3.weight.data.sum()


decisions, paths, _ =  inner_decisions(trees[44], [], 0)
NeN= NN(decisions, paths)


In [159]:
NeN(torch.zeros(42))

tensor(0.0883, grad_fn=<AddBackward0>)