# Main experiment: sparse feature-selection masks + decision nodes on MNIST (the real work starts here)

In [1]:
# let us load the pytorch libraries
import torch
import torchvision
# Loading submodules
import torch.nn as nn
# Variable is deprecated since PyTorch 0.4 (plain tensors track gradients); imported only for legacy code
from torch.autograd import Variable
import torch.optim as optim
import pandas as pd
import numpy as np

In [2]:
# Training configuration.
batch_size_train = 256
batch_size_test = 256
# Fixed seed so parameter initialisation and DataLoader shuffling (both drawn
# from torch's default generator) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Build the MNIST train / test DataLoaders.
# Both splits share one preprocessing pipeline: ToTensor, then Normalize with
# mean 0.1307 / std 0.3081 (presumably the MNIST training-set pixel statistics).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])


def make_mnist_loader(train, batch_size, root='./files/'):
    """Return a DataLoader over the MNIST train (``train=True``) or test split.

    The dataset is downloaded into ``root`` on first use. Batches are
    shuffled, and ``drop_last=True`` guarantees every batch has exactly
    ``batch_size`` samples.
    """
    dataset = torchvision.datasets.MNIST(root, train=train, download=True,
                                         transform=mnist_transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=True, drop_last=True)


train_loader = make_mnist_loader(train=True, batch_size=batch_size_train)
test_loader = make_mnist_loader(train=False, batch_size=batch_size_test)

In [4]:
# Feature selection: one learnable sparse attention mask per tree.
class feature_selection_node(nn.Module):
    """Learn ``number_of_trees`` sparse attention masks over flattened inputs.

    Each mask is ``sigmoid(attention_mask)`` with every entry except its
    ``k`` largest set to 0. ``forward`` multiplies every input row by every
    mask.

    Args:
        number_of_trees: number of masks to learn (one per tree).
        batch_size: stored as ``self.batch`` for backwards compatibility;
            ``forward`` takes the batch size from its input instead.
        k: number of features each mask keeps (default 200).
        num_features: length of a flattened input row (default 28*28, the
            MNIST image size).
    """

    def __init__(self, number_of_trees, batch_size, k=200, num_features=28 * 28):
        super(feature_selection_node, self).__init__()
        self.num_of_trees = number_of_trees
        self.batch = batch_size
        self.k = k
        self.num_features = num_features
        # Raw mask logits, initialised uniformly in [-1, 1].
        self.attention_mask = nn.Parameter(
            torch.empty(number_of_trees, num_features).uniform_(-1.0, 1.0))

    def forward(self, x):
        """Apply every mask to every sample.

        Args:
            x: tensor whose elements flatten to rows of ``num_features``
               (e.g. a ``(B, 1, 28, 28)`` MNIST batch).

        Returns:
            ``(masked, attention)``: ``masked`` has shape
            ``(B, number_of_trees, num_features)`` and ``attention`` is the
            sparse ``(number_of_trees, num_features)`` mask matrix.
        """
        x = x.reshape(-1, self.num_features)
        scores = torch.sigmoid(self.attention_mask)
        # Keep only the k largest scores per mask; scatter them into zeros on
        # the parameter's device/dtype (gradients still reach the kept scores).
        top_values, top_idx = torch.topk(scores, k=self.k, dim=1)
        attention = torch.zeros_like(scores).scatter(1, top_idx, top_values)
        # (B, 1, F) * (1, T, F) -> (B, T, F): one broadcast instead of a loop
        # into a buffer pre-sized to self.batch, so any batch size works.
        masked = x.unsqueeze(1) * attention.unsqueeze(0)
        return masked, attention

In [5]:
# Decision nodes: per-(tree, leaf) class scores and their impurity.
class decision_node(nn.Module):
    """Score every leaf of every tree and summarise its class distribution.

    Args:
        number_of_trees: number of trees ``T``.
        max_num_of_leaf_nodes: number of leaves ``L`` per tree.
        classes: number of classes ``C``.
        batch: stored as ``self.batch`` for backwards compatibility;
            ``forward`` takes the batch size from its input instead.
        num_features: length of each input feature vector (default 28*28).
    """

    def __init__(self, number_of_trees, max_num_of_leaf_nodes, classes, batch,
                 num_features=28 * 28):
        super(decision_node, self).__init__()
        self.leaf = max_num_of_leaf_nodes
        self.tree = number_of_trees
        self.classes = classes
        self.batch = batch
        # Linear map (shared by all trees) from features to one score per leaf.
        self.symbolic_path_weights = nn.Linear(num_features, max_num_of_leaf_nodes, bias=True)
        self.hardtanh = nn.Hardtanh()
        self.softmax = nn.Softmax(dim=-1)
        # Per-(tree, leaf) class contribution vectors, initialised in [-1, 1].
        self.contribution = nn.Parameter(
            torch.empty(number_of_trees, max_num_of_leaf_nodes, classes).uniform_(-1.0, 1.0))

    def forward(self, x):
        """Compute leaf path scores and per-leaf impurity.

        Args:
            x: ``(B, T, num_features)`` tensor, e.g. the masked output of
               ``feature_selection_node``.

        Returns:
            ``(symbolic_paths, impurity)``, both shaped ``(B, T, L)``.
            ``symbolic_paths`` is the Hardtanh-clipped linear score.
            ``impurity`` is ``sum_c (1 - p_c**2)``, where ``p`` is the softmax
            over classes of ``path_score * contribution``. This equals
            ``C - sum_c p_c**2``, i.e. Gini impurity shifted by ``C - 1``.
        """
        symbolic_paths = self.hardtanh(self.symbolic_path_weights(x))
        # (B, T, L, 1) * (T, L, C) -> (B, T, L, C): the outer product of every
        # path score with its leaf's contribution vector, computed in one shot.
        class_value = symbolic_paths.unsqueeze(-1) * self.contribution
        probs = self.softmax(class_value)
        impurity = (1.0 - probs * probs).sum(dim=-1)
        return symbolic_paths, impurity



In [6]:
# Build the model: 100 attention masks feeding a decision_node with
# 100 trees x 200 leaves over 10 classes, optimised jointly with SGD.
mask = feature_selection_node(100, batch_size_train)
decision = decision_node(100, 200, 10, batch_size_train)
params = [*mask.parameters(), *decision.parameters()]
optimizer = optim.SGD(params, lr=1e-3, momentum=0.5)


In [7]:
# Training-loop settings and containers for the loss curves.
n_epochs = 3
log_interval = 10  # train() logs every log_interval batches
train_losses, train_counter, test_losses = [], [], []
# i * len(training set) for i = 0 .. n_epochs.
test_counter = [epoch_idx * len(train_loader.dataset)
                for epoch_idx in range(n_epochs + 1)]
def frequency(d, num_classes=10):
    """Impurity score of a batch of labels: ``sum_c (1 - p_c**2)``.

    ``p_c`` is the fraction of items in ``d`` equal to class ``c``. This is
    the same formula ``decision_node`` applies to its softmax over all of its
    ``classes``, so it is evaluated over ``num_classes`` classes. Classes
    absent from ``d`` have ``p_c = 0`` and contribute 1 each.

    Args:
        d: iterable/array of sortable labels (e.g. ``target.numpy()``);
           multi-dimensional input is flattened.
        num_classes: size of the class set. ``None`` sums only over the
           classes actually present in ``d``.

    Returns:
        0-dim float32 torch tensor.

    Raises:
        ValueError: if ``d`` is empty.
    """
    labels = np.asarray(d).ravel()
    if labels.size == 0:
        raise ValueError("frequency() needs at least one label")
    _, counts = np.unique(labels, return_counts=True)
    proportions = torch.from_numpy(counts / counts.sum()).float()
    impurity = (1 - proportions * proportions).sum(-1)
    if num_classes is not None:
        # Without this, a batch missing some class would score lower than the
        # model's impurity over all classes and give an inconsistent target.
        impurity = impurity + float(max(num_classes - counts.size, 0))
    return impurity
def train(epoch):
    """Run one training epoch over ``train_loader``.

    For each batch, the per-leaf impurities from ``decision`` are compared
    against a detached reference. In that reference, leaf ``i`` holds the
    impurity of leaf ``i-1``, and leaf 0 holds ``frequency`` of the batch
    labels. MarginRankingLoss with target +1 is zero only when the reference
    exceeds the model value by the margin, so it pushes each leaf's impurity
    below its predecessor's.

    Reads the module-level ``mask``, ``decision``, ``optimizer``,
    ``train_loader``, ``log_interval`` and ``batch_size_train``. Every
    ``log_interval`` batches it prints progress and appends to
    ``train_losses`` / ``train_counter``.
    """
    mask.train()
    decision.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        masked_output, _ = mask(data)
        _, weights = decision(masked_output)
        # torch.roll returns a new tensor (same semantics as np.roll), so the
        # in-place write below never touches `weights`; detached => fixed target.
        weights_output = torch.roll(weights.detach(), shifts=1, dims=-1)
        weights_output[..., 0] = frequency(target.cpu().numpy())
        # Sized from the actual output instead of a hard-coded (256, 100, 200).
        flag = torch.ones_like(weights)
        loss = torch.nn.functional.margin_ranking_loss(
            weights_output, weights, flag, margin=0.0000001)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset)))

In [None]:
# Fresh models and optimiser on every run of this cell, then train.
mask = feature_selection_node(100, batch_size_train)
decision = decision_node(100, 200, 10, batch_size_train)
params = list(mask.parameters()) + list(decision.parameters())
optimizer = optim.SGD(params, lr=1e-3, momentum=.5)

# The models were just rebuilt, so drop loss history from any previous run of
# this cell; clear() keeps the same list objects that train() appends to.
train_losses.clear()
train_counter.clear()

for epoch in range(1, n_epochs + 1):
    train(epoch)
    

