In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
# Prefer the first CUDA device when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

class SDT(nn.Module):
    """Fast implementation of soft decision tree in PyTorch.

    Parameters
    ----------
    input_dim : int
      The number of input dimensions (after flattening each sample).
    output_dim : int
      The number of output dimensions. For example, for a multi-class
      classification problem with `K` classes, it is set to `K`.
    depth : int, default=5
      The depth of the soft decision tree. Since the soft decision tree is
      a full binary tree, setting `depth` to a large value drastically
      increases the training and evaluation cost.
    lamda : float, default=1e-3
      The coefficient of the regularization term in the training loss. Please
      refer to the paper on the formulation of the regularization term.
    use_cuda : bool, default=False
      Kept for backward compatibility and recorded in `device`. Helper
      tensors built during the forward pass are created on the device of the
      input batch, so the model works wherever the module and its inputs
      actually live (move the module with `.to(...)` as usual).

    Attributes
    ----------
    internal_node_num_ : int
      The number of internal nodes in the tree. Given the tree depth `d`, it
      equals to :math:`2^d - 1`.
    leaf_node_num_ : int
      The number of leaf nodes in the tree. Given the tree depth `d`, it equals
      to :math:`2^d`.
    penalty_list : list
      A list storing the layer-wise coefficients of the regularization term.
    inner_nodes : torch.nn.Sequential
      A container that simulates all internal nodes in the soft decision tree.
      The sigmoid activation function is concatenated to simulate the
      probabilistic routing mechanism.
    leaf_nodes : torch.nn.Linear
      A `nn.Linear` module that simulates all leaf nodes in the tree.
    device : torch.device
      The device implied by `use_cuda`; informational only.

    Raises
    ------
    ValueError
      If `depth` is not strictly positive or `lamda` is negative.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            depth=5,
            lamda=1e-3,
            use_cuda=False):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.depth = depth
        self.lamda = lamda
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self._validate_parameters()

        self.internal_node_num_ = 2 ** self.depth - 1
        self.leaf_node_num_ = 2 ** self.depth

        # Different penalty coefficients for nodes in different layers: the
        # coefficient is halved at every level below the root.
        self.penalty_list = [
            self.lamda * (2 ** (-layer)) for layer in range(self.depth)
        ]

        # Initialize internal nodes and leaf nodes, the input dimension on
        # internal nodes is added by 1, serving as the bias.
        self.inner_nodes = nn.Sequential(
            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),
            nn.Sigmoid(),
        )

        self.leaf_nodes = nn.Linear(self.leaf_node_num_,
                                    self.output_dim,
                                    bias=False)

    def forward(self, X, is_training_data=False):
        """Return the predictions for `X`.

        When `is_training_data` is true, the regularization penalty is
        returned as well, as the tuple `(y_pred, penalty)`.
        """
        _mu, _penalty = self._forward(X)
        y_pred = self.leaf_nodes(_mu)

        # When `X` is the training data, the model also returns the penalty
        # to compute the training loss.
        if is_training_data:
            return y_pred, _penalty
        return y_pred

    def _forward(self, X):
        """Implementation on the data forwarding process.

        Returns the leaf-reaching probabilities, shaped
        `(batch_size, leaf_node_num_)`, and the scalar regularization term.
        """
        batch_size = X.size()[0]
        X = self._data_augment(X)

        # For every internal node keep both branch probabilities (p, 1 - p)
        # along the last dimension.
        path_prob = self.inner_nodes(X)
        path_prob = torch.unsqueeze(path_prob, dim=2)
        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)

        # Helper tensors follow the input's device instead of `self.device`
        # so that the `use_cuda` flag can never disagree with the data.
        _mu = X.new_ones(batch_size, 1, 1)
        _penalty = torch.zeros((), device=X.device)

        # Iterate through internal nodes in each layer to compute the final path
        # probabilities and the regularization term.
        begin_idx = 0
        end_idx = 1

        for layer_idx in range(self.depth):
            _path_prob = path_prob[:, begin_idx:end_idx, :]

            # Extract internal nodes in the current layer to compute the
            # regularization term
            _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob)
            _mu = _mu.reshape(batch_size, -1, 1).repeat(1, 1, 2)

            _mu = _mu * _path_prob  # update path probabilities

            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (layer_idx + 1)

        mu = _mu.reshape(batch_size, self.leaf_node_num_)

        return mu, _penalty

    def _cal_penalty(self, layer_idx, _mu, _path_prob):
        """
        Compute the regularization term for internal nodes in different layers.

        `_mu` holds the probability of reaching each of the `2 ** layer_idx`
        nodes of the layer and `_path_prob` the two branch probabilities of
        each of those nodes. Returns a 0-dim tensor.
        """
        batch_size = _mu.size()[0]
        _mu = _mu.reshape(batch_size, 2 ** layer_idx)
        _path_prob = _path_prob.reshape(batch_size, 2 ** (layer_idx + 1))

        # Branch `k` belongs to node `k // 2`; duplicating every column of
        # `_mu` lines the reach probabilities up with their two branches.
        parent_mu = _mu.repeat_interleave(2, dim=1)

        # Batch-averaged branch probability, weighted by how likely each
        # sample is to reach the parent node.
        alpha = (_path_prob * parent_mu).sum(dim=0) / parent_mu.sum(dim=0)

        # A saturated sigmoid yields alpha == 0 or 1 and log() would return
        # -inf; keep alpha strictly inside (0, 1) so the loss stays finite.
        eps = torch.finfo(alpha.dtype).eps
        alpha = alpha.clamp(eps, 1 - eps)

        coeff = self.penalty_list[layer_idx]

        return -0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha)).sum()

    def _data_augment(self, X):
        """Add a constant input `1` onto the front of each sample."""
        batch_size = X.size()[0]
        X = X.reshape(batch_size, -1)
        # Build the bias column next to the data rather than on `self.device`.
        bias = torch.ones(batch_size, 1, device=X.device)

        return torch.cat((bias, X), 1)

    def _validate_parameters(self):
        """Raise `ValueError` for a non-positive depth or a negative lamda."""

        if not self.depth > 0:
            msg = ("The tree depth should be strictly positive, but got {}"
                   " instead.")
            raise ValueError(msg.format(self.depth))

        if not self.lamda >= 0:
            msg = (
                "The coefficient of the regularization term should not be"
                " negative, but got {} instead."
            )
            raise ValueError(msg.format(self.lamda))

In [None]:
class TorchDataset(torch.utils.data.Dataset):

        def __init__(self, *data, **options):
            
            n_data = len(data)
            if n_data == 0:
                raise ValueError("At least one set required as input")

            self.data = data
            means = options.pop('means', None)
            stds = options.pop('stds', None)
            self.transform = options.pop('transform', None)
            self.test = options.pop('test', False)
            
            if options:
                raise TypeError("Invalid parameters passed: %s" % str(options))
            
            if means is not None:
                assert stds is not None, "must specify both <means> and <stds>"

                self.normalize = lambda data: [(d - m) / s for d, m, s in zip(data, means, stds)]

            else:
                self.normalize = lambda data: data

        def __len__(self):
            return len(self.data[0])

        def __getitem__(self, idx):
            data = self.normalize([s[idx] for s in self.data])
            if self.transform:

                if self.test:
                    data = sum([[self.transform.test_transform(d)] * 2 for d in data], [])
                else:
                    data = sum([self.transform(d) for d in data], [])
                
            return data

In [None]:
#@title Synthetic data
def set_npseed(seed):
    """Seed NumPy's global random number generator with `seed`."""
    np.random.seed(seed=seed)


def set_torchseed(seed):
    """Seed PyTorch's CPU and CUDA generators and request deterministic cuDNN."""
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    # Fixed cuDNN kernels: deterministic algorithms on, auto-tuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


#classification data

def data_gen_decision_tree(num_data=1000, dim=2, seed=0, w_list=None, b_list=None, vals=None, num_levels=2, threshold=None):
    """Generate unit-norm points labelled by a complete oblique decision tree.

    The tree has `2**num_levels - 1` internal nodes and `2**num_levels`
    leaves, numbered in heap order (root 0; children of node `i` are
    `2*i + 1` (left) and `2*i + 2` (right)). At internal node `i` a point `x`
    goes right when `w_list[i] @ x + b_list[i] > 0` and left otherwise.

    Parameters
    ----------
    num_data : int
        Number of points drawn before pruning.
    dim : int
        Dimensionality of the points.
    seed : int
        Seed for NumPy's global random generator.
    w_list : array-like of shape (num_internal_nodes, dim), optional
        Hyperplane normals. When omitted, random unit-norm normals are drawn
        and the biases are set to zero (any `b_list` passed is ignored).
    b_list : array-like of shape (num_internal_nodes,), optional
        Hyperplane biases; zeros when omitted.
    vals : array-like, optional
        Node labels, either one per node (heap order, internal entries are
        ignored) or one per leaf. By default leaves get the parity of their
        node index and internal nodes get -99.
    num_levels : int
        Number of internal levels of the tree; must be at least 1.
    threshold : float, optional
        Points whose distance score `|w @ x + b|` to any hyperplane is not
        strictly greater than this value are dropped. When omitted, a
        module-level `threshold` variable is used if one exists, else 0.

    Returns
    -------
    ((data_x, labels), (w_list, b_list, vals), stats)
        The kept points and their labels, the tree parameters, and the
        number of kept points passing through every node (heap order).

    Raises
    ------
    ValueError
        If `num_levels` is smaller than 1.
    """
    if num_levels < 1:
        raise ValueError("num_levels must be at least 1, got %s" % num_levels)

    np.random.seed(seed)

    num_internal_nodes = 2 ** num_levels - 1
    num_leaf_nodes = 2 ** num_levels
    num_nodes = num_internal_nodes + num_leaf_nodes

    if threshold is None:
        # Backward compatibility: earlier versions read a notebook-level
        # `threshold` variable instead of taking an argument.
        threshold = globals().get("threshold", 0)

    if vals is None:
        # Label every node by the parity of its index; only leaves matter,
        # so internal nodes get the sentinel -99.
        vals = np.arange(0, num_nodes, 1, dtype=np.int32) % 2
        vals[:num_internal_nodes] = -99
    else:
        vals = np.asarray(vals)

    if w_list is None:
        w_list = np.random.standard_normal((num_internal_nodes, dim))
        w_list = w_list / np.linalg.norm(w_list, axis=1)[:, None]  # unit-norm normals
        b_list = np.zeros(num_internal_nodes)
    else:
        w_list = np.asarray(w_list)
        if b_list is None:
            b_list = np.zeros(num_internal_nodes)
        else:
            b_list = np.asarray(b_list)

    # Points uniformly distributed on the unit sphere.
    data_x = np.random.standard_normal((num_data, dim))
    data_x /= np.sqrt(np.sum(data_x ** 2, axis=1, keepdims=True))

    # Score of every point at every internal node; the bias is applied
    # exactly once so routing agrees with the node statistics below.
    relevant_stats = data_x @ w_list.T + b_list

    # Route every point from the root down to a leaf. Integer indexing has no
    # limit on the number of nodes (unlike `np.choose`).
    rows = np.arange(num_data)
    curr_index = np.zeros(num_data, dtype=int)
    for _ in range(num_levels):
        decision_variable = relevant_stats[rows, curr_index]
        # Right child (2*i + 2) if the score is positive, else left (2*i + 1).
        curr_index = 2 * curr_index + 1 + (decision_variable > 0)

    # Leaf label of every point; `vals` may cover all nodes or only leaves.
    if len(vals) == num_leaf_nodes:
        labels = vals[curr_index - num_internal_nodes]
    else:
        labels = vals[curr_index]

    # Drop the points that are too close to any hyperplane.
    bound_dist = np.min(np.abs(relevant_stats), axis=1)
    keep = bound_dist > threshold
    data_x_pruned = data_x[keep]
    labels_pruned = labels[keep]

    # Side (+1 / -1) of every kept point with respect to every hyperplane.
    sides = np.sign(data_x_pruned @ w_list.T + b_list)

    # nodes_active[p, n] == 1 iff point p passes through node n.
    nodes_active = np.zeros((len(data_x_pruned), num_nodes), dtype=np.int32)
    nodes_active[:, 0] = 1  # every point starts at the root
    for node in range(1, num_nodes):
        parent = (node - 1) // 2
        is_right_child = node == 2 * parent + 2
        took_branch = sides[:, parent] > 0 if is_right_child else sides[:, parent] < 0
        nodes_active[:, node] = nodes_active[:, parent] * took_branch

    stats = nodes_active.sum(axis=0)
    return ((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)

In [None]:
class Dataset_syn:
    """Bundle the train / validation / test splits of a named dataset.

    For `dataset == "syn"` the splits are taken from the module-level
    variables `train_data`, `train_data_labels`, `vali_data`,
    `vali_data_labels`, `test_data` and `test_data_labels`, which must exist
    when the object is created. For any other name no split attributes are
    set. `data_path` is only stored.
    """

    def __init__(self, dataset, data_path='./DATA'):
        self.dataset = dataset
        self.data_path = data_path

        if dataset != "syn":
            return

        self.X_train, self.y_train = train_data, train_data_labels
        self.X_valid, self.y_valid = vali_data, vali_data_labels
        self.X_test, self.y_test = test_data, test_data_labels

**Change this cell to run different configs**

In [None]:
def onehot_coding(target, device, output_dim):
    """Return a float32 one-hot matrix for the integer class labels `target`.

    The result has one row per label and `output_dim` columns and is created
    on `device`; `target` is expected to live on that same device.
    """
    num_samples = target.size()[0]
    onehot = torch.zeros(num_samples, output_dim,
                         dtype=torch.float32, device=device)
    class_column = target.view(-1, 1)
    # Write a 1 at the column given by each row's class index.
    return onehot.scatter_(1, class_column, 1.0)


# Experiment-wide settings shared by every configuration below.
seed = 365
num_levels = 4
threshold = 0  # data separation distance
output_dim = 1

# One configuration per synthetic dataset: (input dimensions, sample count).
data_configs = [
    dict(input_dim=n_features, num_data=n_samples)
    for n_features, n_samples in ((20, 40000), (100, 60000), (500, 100000))
]

# Train and evaluate one soft decision tree per data configuration.
for config in data_configs:
    input_dim = config["input_dim"]
    num_data = config["num_data"]

    ((data_x, labels), (w_list, b_list, vals), stats) = data_gen_decision_tree(
        dim=input_dim, seed=seed, num_levels=num_levels, num_data=num_data)
    seed_set = seed
    w_list_old = np.array(w_list)
    b_list_old = np.array(b_list)
    # Class balance of the generated data.
    print(int(np.sum(labels == 1)))
    print(int(np.sum(labels == 0)))
    print("Seed= ", seed_set)

    # The generator may have pruned points, so recount before splitting
    # 50% / 25% / remainder into train / validation / test.
    num_data = len(data_x)
    num_train = num_data // 2
    num_vali = num_data // 4
    num_test = num_data // 4

    train_data = data_x[:num_train, :]
    train_data_labels = labels[:num_train]

    vali_data = data_x[num_train:num_train + num_vali, :]
    vali_data_labels = labels[num_train:num_train + num_vali]

    test_data = data_x[num_train + num_vali:, :]
    test_data_labels = labels[num_train + num_vali:]

    # Parameters
    output_dim = 2         # the number of outputs (binary labels: 0 / 1)
    depth = 5              # tree depth
    lamda = 1e-3           # coefficient of the regularization term
    lr = 1e-3              # learning rate
    weight_decay = 5e-4    # weight decay
    batch_size = 128       # batch size
    epochs = 500           # the number of training epochs
    log_interval = 100     # the number of batches to wait before printing logs
    # Follow the device selected at the top of the notebook instead of
    # forcing CUDA, so the cell also runs on CPU-only machines.
    use_cuda = device.type == "cuda"

    # Model and Optimizer
    tree = SDT(input_dim, output_dim, depth, lamda, use_cuda).to(device)
    tree = tree.float()

    optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    DATA_NAME = "syn"
    data = Dataset_syn(DATA_NAME)
    BATCH_SIZE = batch_size
    train_loader = DataLoader(TorchDataset(data.X_train, data.y_train), batch_size=BATCH_SIZE, num_workers=16, shuffle=True)
    valloader = DataLoader(TorchDataset(data.X_valid, data.y_valid), batch_size=BATCH_SIZE*2, num_workers=16, shuffle=False)
    test_loader = DataLoader(TorchDataset(data.X_test, data.y_test), batch_size=BATCH_SIZE*2, num_workers=16, shuffle=False)

    # Utils
    best_testing_acc = 0.0
    testing_acc_list = []
    training_loss_list = []
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):

        # Training
        tree.train()
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):

            n_samples = batch_x.size()[0]
            batch_x = batch_x.to(device).float()
            # CrossEntropyLoss expects int64 class indices.
            batch_y = batch_y.to(device).to(torch.int64).view(-1)

            output, penalty = tree(batch_x, is_training_data=True)
            loss = criterion(output, batch_y) + penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print training status
            if batch_idx % log_interval == 0:
                pred = output.detach().argmax(dim=1)
                correct = int(pred.eq(batch_y).sum())
                loss_value = loss.item()

                msg = (
                    "Epoch: {:02d} | Batch: {:03d} | Loss: {:.5f} |"
                    " Correct: {:03d}/{:03d}"
                )
                print(msg.format(epoch, batch_idx, loss_value, correct, n_samples))
                training_loss_list.append(loss_value)

        # Evaluating
        tree.eval()
        correct = 0

        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.to(device).float()
                batch_y = batch_y.to(device).view(-1)

                # The arg-max of the logits equals the arg-max of the softmax.
                pred = tree(batch_x).argmax(dim=1)
                correct += int(pred.eq(batch_y).sum())

        accuracy = 100.0 * float(correct) / len(test_loader.dataset)

        if accuracy > best_testing_acc:
            best_testing_acc = accuracy

        msg = (
            "\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) |"
            " Historical Best: {:.3f}%\n"
        )
        print(
            msg.format(
                epoch, correct,
                len(test_loader.dataset),
                accuracy,
                best_testing_acc
            )
        )
        testing_acc_list.append(accuracy)