In [1]:
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import grad, jit, vmap

In [2]:
class Node:
    """
    Class for Node with split
    """
    def __init__(
        self,
        length_history,
        length_observation,
        status=None,
        root=None,
        index=None,
        max_depth=None,
        
    ):
        
        # Savind data about length of vectors
        self.length_history = length_history
        self.length_observation = length_observation
        
        # Saving the hyperparameters
        self.max_depth = max_depth if max_depth else 5
        
        # Setting index of elemnet
        self.index = index if index else 0
        self.depth = int(np.floor(np.log2(self.index + 2)))
        self.status = status if status else "node"
        
    
        # Init root with None
        self.root = root
        
        # Init split weights and split bias
#         self.w = np.random.sample(length_history + length_observation)
#         self.b = np.random.rand()
#         self.wb = np.random.sample(length_history + length_observation + 1)

        
    def init_node(self, index):
        # Init childs with Leafs
        self.l_child = \
            Node(self.length_history, \
                self.length_observation, \
                status='leaf',\
                root=self.root,\
                index=2*index + 1, \
                max_depth=self.max_depth)
        self.r_child = \
            Node(self.length_history, \
                self.length_observation, \
                status='leaf',\
                root=self.root,\
                index=2*index + 2,\
                max_depth=self.max_depth)
        
        
    
    def predict(self, 
                observations: list,
                next_history: list,
                pred_index,
                node_parameters):

        current = self
        index = 0
        product_path = 1
        l1_norm_sum = 0
        
        # проходим по дереву пока не упираемся в leaf
        while current.status == 'node':
            # use formula for probabilty
#             probabilty = self.sigmoid(current.b + jnp.dot(current.w, jnp.concatenate((observations, history), axis=None)))
            probabilty = \
                self.sigmoid(jnp.dot(node_parameters[index], \
                jnp.concatenate((observations, jnp.tanh(next_history[pred_index]), 1), axis=None)))
            l1_norm_sum += jnp.sum(jnp.square(node_parameters[index]))
#             print(float(probabilty))
#             print(probabilty)
#             probabilty = 0
            
            if probabilty > 0.5:
                product_path = jnp.multiply(product_path, probabilty)
                current = current.l_child
                index = 2*index + 1
            else:
                product_path = jnp.multiply(product_path, 1 - probabilty)
                current = current.r_child
                index = 2*index + 2
                
            
                
        # возвращаем предсказанное действие, следующую историю, и предсказания относительно
        # будущих наблюдений        
#         return (jnp.tanh(action_parameters[index]), index, jnp.tanh(pred_observation[index]), product_path)
        return index, product_path, l1_norm_sum               
    
    def sigmoid(self, value):
        return 1 / (1 + jnp.exp(-value))



In [3]:
def init_root(length_history, length_observation):
    root = Node(length_history,length_observation)
    root.root = root
    root.init_node(root.index)
    return root

In [4]:
root = init_root(length_history=10,length_observation=10)

In [5]:
root.l_child = Node(length_history=10,length_observation=10, root=root,index=root.l_child.index )
root.l_child.init_node(root.l_child.index)

In [6]:
root.r_child = Node(length_history=10,length_observation=10, root=root,index=root.r_child.index )
root.r_child.init_node(root.r_child.index)

In [7]:
def softmax(vector):
    sum_ = jnp.sum(jnp.exp(vector))
    return jnp.exp(vector)/sum_

In [8]:
length_history = 10
length_observation = 10
max_depth = 3
node_parameters = 2 * np.random.sample((2**max_depth - 1, length_history + length_observation + 1)) - 1
action_parameters = 2 * np.random.sample((2**max_depth - 1, 2)) - 1
next_history = 2 * np.random.sample((2**max_depth - 1, length_history)) - 1
pred_observation = 100 * np.random.sample((2**max_depth -1, length_observation))

In [9]:
action_parameters[0]

array([-0.87727088,  0.88825722])

In [10]:
node_parameters[0]

array([ 0.4317521 ,  0.27145478,  0.28385271,  0.93766197, -0.74674141,
       -0.63809734,  0.42233557,  0.04765333, -0.52090349,  0.22621116,
        0.72803344,  0.16298103,  0.03745352, -0.33811173,  0.55082402,
       -0.46434567,  0.3277548 ,  0.68219854, -0.05521256, -0.47591801,
        0.17982955])

In [11]:
initial_history = next_history[0]
first_observation = 2 * np.random.sample(length_observation) - 1
index, product_path, l1_norm_sum = \
                root.predict(first_observation,
                next_history, 
                0,
                node_parameters)





In [12]:
l1_norm_sum

DeviceArray(12.587061, dtype=float32)

In [13]:
action_parameters[index]

array([ 0.27325662, -0.51127618])

In [14]:
def to_actions(vector):
    actions = {0: 'not fraud', 1: 'fraud'}
#     print(jnp.argmax(vector))
    return actions[int(jnp.argmax(vector))]

In [15]:
def process(element):
#     print(np.log2(element.index + 1))
    print('\t'*int(np.floor(np.log2(element.index+1))) + element.status, end='')
    if element.status!='node':
        print('\t'*int(np.floor(np.log2(element.index+1))), end='')
        to_actions(action_parameters[element.index])
        print(to_actions(action_parameters[element.index]))
        print()
#         print(action_parameters[element.index])
    if element.status=='node':
        print(': some split')
        process(element.l_child)
        process(element.r_child)
    

In [16]:
process(root)

node: some split
	node: some split
		leaf		fraud

		leaf		fraud

	node: some split
		leaf		not fraud

		leaf		fraud



In [17]:
def loss_function(node_parameters,
                    action_parameters,
                    next_history,
                    pred_observation, 
                    pred_index, 
                    observation, 
                    actions, 
                    delta1,
                    delta2):
    
#     predicted_actions, next_history_predicted_index, pred_obs, product_path = \
    index, product_path, l1_norm_sum = \
        root.predict(observation,
                    next_history,
                    pred_index,
                    node_parameters)
    predicted_actions = action_parameters[index]
    
    pred_obs = pred_observation[index]
    general_loss = - jnp.dot(product_path, jnp.sum(actions * jnp.log(softmax(predicted_actions))))
    additional_loss = delta1 * jnp.sum(jnp.square(observation - pred_obs))
    return general_loss + additional_loss + l1_norm_sum

In [18]:
actions = jnp.array([0, 1])
delta1 = 0.001
delta2 = 0.01
pred_index = 0
loss_function(node_parameters,
                action_parameters,
                next_history,
                pred_observation,
                pred_index,
                first_observation,
                actions,
                delta1,
                delta2)

DeviceArray(36.221844, dtype=float32)

In [19]:
grad_loss_node = grad(loss_function, 0)


In [20]:
grad_loss_node(node_parameters,
                action_parameters,
                next_history,
                pred_observation,
                0,
                first_observation,
                actions,
                delta1,
                delta2)


DeviceArray([[ 0.86659646,  0.56431055,  0.58287954,  1.9171195 ,
              -1.5728126 , -1.3332481 ,  0.869006  ,  0.0472103 ,
              -1.0875912 ,  0.49178597,  1.4297829 ,  0.37399054,
               0.13316715, -0.7445415 ,  1.038306  , -1.0014179 ,
               0.6938259 ,  1.3136001 , -0.14549223, -0.8808924 ,
               0.2613884 ],
             [ 0.        ,  0.        ,  0.        ,  0.        ,
               0.        ,  0.        ,  0.        ,  0.        ,
               0.        ,  0.        ,  0.        ,  0.        ,
               0.        ,  0.        ,  0.        ,  0.        ,
               0.        ,  0.        ,  0.        ,  0.        ,
               0.        ],
             [ 0.36414924, -1.8371958 , -0.60694915, -1.8038634 ,
              -1.3916855 , -0.9426239 , -1.4154408 ,  0.82592237,
              -0.5726112 ,  0.60350806, -1.4007857 ,  0.9245302 ,
              -1.2688873 ,  1.7100022 ,  0.00406364,  1.1357585 ,
               1.398

In [21]:

def process_batch(batch_observation,
                  batch_actions,
                  length_history, 
                  length_observation,
                  node_parameters,
                  action_parameters,
                  next_history,
                  pred_observation):
    
    node_params_grads = np.zeros((2**max_depth -1, length_history + length_observation + 1))
    action_params_grads = np.zeros((2**max_depth -1, 2))
    next_history_grads = np.zeros((2**max_depth-1, length_history))
    pred_observation_grads = np.zeros((2**max_depth -1, length_observation))
    cumulative_loss = 0

    current_index = 0
    
    for i in range(len(batch_observation)):
        cumulative_loss += \
            loss_function(node_parameters, \
                action_parameters, \
                next_history, \
                pred_observation, \
                current_index, \
                batch_observation[i], \
                batch_actions[i], \
                delta1, \
                delta2)


        
        node_params_grads += \
            grad(loss_function, 0)(node_parameters,
                                   action_parameters,
                                   next_history,
                                   pred_observation,
                                   current_index,
                                   batch_observation[i],
                                   batch_actions[i],
                                   delta1,
                                   delta2)
        
        action_params_grads += \
            grad(loss_function, 1)(node_parameters,
                                   action_parameters,
                                   next_history,
                                   pred_observation,
                                   current_index,
                                   batch_observation[i],
                                   batch_actions[i],
                                   delta1,
                                   delta2)
        next_history_grads += \
            grad(loss_function, 2)(node_parameters,
                                   action_parameters,
                                   next_history,
                                   pred_observation,
                                   current_index,
                                   batch_observation[i],
                                   batch_actions[i],
                                   delta1,
                                   delta2)
        pred_observation_grads += \
            grad(loss_function, 3)(node_parameters,
                                   action_parameters,
                                   next_history,
                                   pred_observation,
                                   current_index,
                                   batch_observation[i],
                                   batch_actions[i],
                                   delta1,
                                   delta2)

        current_index, product_path, l1_norm_sum = \
            root.predict(batch_observation[i],
                    next_history,
                    current_index,
                    node_parameters)
            
    return (node_params_grads, 
            action_params_grads,
            next_history_grads,
            pred_observation_grads,
            cumulative_loss)
            

In [22]:
seq_len = 100
delta1 = 0.001
delta2 = 0.01
batch_observation = np.random.sample((seq_len, length_observation))

true_actions = np.random.randint(0, 2, seq_len)
batch_actions = np.zeros((seq_len, true_actions.max()+1))
batch_actions[np.arange(seq_len),true_actions] = 1


In [23]:
gradients = process_batch(batch_observation, 
              batch_actions, 
              length_history, 
              length_observation,
              node_parameters,
              action_parameters,
              next_history,
              pred_observation)

In [24]:
# node_parameters,
# action_parameters,
# next_history,
# pred_observation,
gradients[2]

DeviceArray([[ 1.4554350e-01,  9.0336874e-02,  2.8860383e-02,
              -5.4120108e-02, -4.0215999e-03, -2.9870879e-02,
               2.6980756e-02,  1.5608296e-01, -1.0700166e-02,
              -8.5120961e-02],
             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00],
             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00],
             [-1.1615208e+00, -5.9248686e-01,  7.5844604e-01,
              -6.6037351e-01, -2.8949120e+00,  8.7489724e-01,
              -1.7313139e+00, -3.0012071e+00,  1.6553259e+00,
               2.0595407e+00],
             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        

In [25]:
def generate_sequence(min_length, max_length, length_observation):
    length = np.random.randint(min_length, max_length)
    sequence_observations = np.zeros((max_length, length_observation))
    observations = 100 * np.random.sample((length, length_observation))
#     base[0:max_length - 1] = sequence_observations
    for i in range(length):
        sequence_observations[i] = observations[i]
    
    sequence_actions = np.zeros((max_length, 2))

    true_actions = np.random.randint(0, 2, length)
    actions = np.zeros((length, true_actions.max()+1))
    actions[np.arange(length),true_actions] = 1
    
    for i in range(length):
        sequence_actions[i] = actions[i]
    
    
    return length, sequence_observations, sequence_actions

In [29]:
# generate_sequence(min_length=5, max_length=10, length_observation=10)

In [60]:
import matplotlib.pyplot as plt