In [40]:
import numpy as np
import torch
from functools import reduce
import matplotlib.pyplot as plt
%matplotlib inline

In [41]:
np.random.seed(1943)
torch.manual_seed(1943)

<torch._C.Generator at 0x2071efaefd0>

In [56]:
from sklearn import datasets
from sklearn.preprocessing import OneHotEncoder    
import pandas as pd
iris = datasets.load_iris()
features=iris.data 
labels=iris.target 
of=OneHotEncoder(sparse=False).fit(labels.reshape(-1,1))  
labels=of.transform(labels.reshape(-1,1))


In [58]:
_x = features[:, 2:4]
_y = labels
d = _x.shape[1]

x = torch.from_numpy(_x.astype(np.float32))
y = torch.from_numpy(np.argmax(_y, axis=1))


In [26]:
def torch_kron_prod(a, b):
    res = torch.einsum('ij,ik->ijk', [a, b])
    res = torch.reshape(res, [-1, np.prod(res.shape[1:])])
    return res

In [27]:
def torch_bin(x, cut_points, temperature=0.1):
    # x is a N-by-1 matrix (column vector)
    # cut_points is a D-dim vector (D is the number of cut-points)
    # this function produces a N-by-(D+1) matrix, each row has only one element being one and the rest are all zeros
    D = cut_points.shape[0]
    W = torch.reshape(torch.linspace(1.0, D + 1.0, D + 1), [1, -1])
    cut_points, _ = torch.sort(cut_points)  # make sure cut_points is monotonically increasing
    b = torch.cumsum(torch.cat([torch.zeros([1]), -cut_points], 0),0)
    h = torch.matmul(x, W) + b
    res = torch.exp(h-torch.max(h))
    res = res/torch.sum(res, dim=-1, keepdim=True)
    return h

In [28]:
def nn_decision_tree(x, cut_points_list, leaf_score, temperature=0.1):
    # cut_points_list contains the cut_points for each dimension of feature
    leaf = reduce(torch_kron_prod,
                  map(lambda z: torch_bin(x[:, z[0]:z[0] + 1], z[1], temperature), enumerate(cut_points_list)))
    return torch.matmul(leaf, leaf_score)

In [61]:
num_cut = [1, 1]  # "Petal length" and "Petal width"
num_leaf = np.prod(np.array(num_cut) + 1)
print(num_leaf)
num_class = 3
cut_points_list = [torch.rand([i], requires_grad=True) for i in num_cut]
print(cut_points_list)
leaf_score = torch.rand([num_leaf, num_class], requires_grad=True)
print(leaf_score)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cut_points_list + [leaf_score], lr=0.01)

4
[tensor([0.7900], requires_grad=True), tensor([0.9033], requires_grad=True)]
tensor([[0.0293, 0.0781, 0.6693],
        [0.9873, 0.4241, 0.8040],
        [0.5723, 0.2832, 0.1554],
        [0.4604, 0.7285, 0.9584]], requires_grad=True)


In [30]:
for i in range(1000):
    optimizer.zero_grad()
    y_pred = nn_decision_tree(x, cut_points_list, leaf_score, temperature=0.1)
    loss = loss_function(y_pred, y)
    loss.backward()
    optimizer.step()
    if i % 200 == 0:
        print(loss.detach().numpy())
print('error rate %.2f' % (1-np.mean(np.argmax(y_pred.detach().numpy(), axis=1)==np.argmax(_y, axis=1))))

12.724256
0.49909937
0.20397107
0.13334799
0.12001412
error rate 0.04
