In [None]:
#@title Imports
%reset -f 

In [None]:
import pandas as pd

In [None]:
import pylab
import scipy.io

In [None]:


import numpy as np
from itertools import product as cartesian_prod

import matplotlib.pyplot as plt


from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os
import argparse
import sys

from sklearn.svm import SVC
np.set_printoptions(precision=2)


def sigmoid(u):
    """Logistic function ``1 / (1 + exp(-u))`` evaluated with NumPy.

    The argument is first clamped to the interval ``[-100, 100]`` so that
    ``np.exp`` never receives a value large enough to overflow.

    Args:
        u: scalar or array-like of real numbers.

    Returns:
        The element-wise logistic of the clamped input.
    """
    clamped = np.clip(u, -100, 100)
    return 1/(1+np.exp(-clamped))


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import pairwise_distances

In [None]:
class Args:
    """Container for the experiment's hyper-parameters.

    Every instance is created with the same fixed values; the notebook reads
    them back as plain attributes.
    """

    def __init__(self):
        # Architecture: number of hidden layers and nodes per hidden layer.
        self.numlayer, self.numnodes = 4, 10
        # Gate sharpness (multiplies the pre-activation inside the sigmoid)
        # and the optimiser learning rate.
        self.beta, self.lr = 5., 1.
        

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

**Variable parameters**

In [None]:
#@title Synthetic data
def set_npseed(seed):
    """Seed NumPy's global (legacy) random number generator.

    Args:
        seed: value forwarded unchanged to ``np.random.seed``.
    """
    np.random.seed(seed=seed)


def set_torchseed(seed):
    """Seed PyTorch's CPU and CUDA generators and force deterministic cuDNN.

    Args:
        seed: integer seed applied to the CPU generator, the current CUDA
            device and all CUDA devices.
    """
    # cuDNN: always pick deterministic kernels and disable the auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Same three seeding calls as before, applied in the same order.
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)


#classification data

def data_gen_decision_tree(num_data=1000, dim=2, seed=0, w_list=None, b_list=None,vals=None, num_levels=2, threshold=None):
    """Sample unit-norm points and label them with a complete oblique decision tree.

    The tree has ``2**num_levels - 1`` internal nodes and ``2**num_levels``
    leaves, numbered in heap order: node 0 is the root and node ``k`` has
    children ``2k+1`` (left) and ``2k+2`` (right).  At internal node ``k`` a
    point ``x`` goes right when ``w_list[k] @ x + b_list[k] > 0`` and left
    otherwise.  The label of a point is ``vals[leaf]`` of the leaf it reaches.

    Args:
        num_data: number of points drawn before pruning.
        dim: input dimension.
        seed: seed for NumPy's global random generator.
        w_list: ``(num_internal_nodes, dim)`` hyperplane normals.  When None,
            random unit-norm normals are drawn and ``b_list`` is set to zeros
            (any ``b_list`` passed alongside ``w_list=None`` is ignored, as
            in the original implementation).
        b_list: ``(num_internal_nodes,)`` hyperplane offsets; defaults to
            zeros when ``w_list`` is given without offsets.
        vals: per-node values of length ``num_internal_nodes + num_leaf_nodes``;
            only leaf entries are used as labels.  When None, leaves get the
            parity of their node index and internal nodes get -99.
        num_levels: depth of the complete tree.
        threshold: points whose smallest ``|w @ x + b|`` over all internal
            nodes is not strictly greater than this margin are dropped.
            When None, a module-level ``threshold`` variable is used if one
            exists (the notebook's previous behaviour), otherwise 0.

    Returns:
        ``((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)``
        where ``stats[k]`` is the number of retained points that pass
        through node ``k``.
    """
    # Same effect as set_npseed(seed); seeding directly keeps this function
    # usable on its own.
    np.random.seed(seed)

    if threshold is None:
        # Backward compatibility: the margin used to be read from a global.
        threshold = globals().get("threshold", 0)

    num_internal_nodes = 2**num_levels - 1
    num_leaf_nodes = 2**num_levels
    num_nodes = num_internal_nodes + num_leaf_nodes

    if vals is None:
        # Label each node by the parity of its index; internal nodes are
        # marked -99 because only leaf values are ever used as labels.
        vals = np.arange(0, num_nodes, 1, dtype=np.int32) % 2
        vals[:num_internal_nodes] = -99
    else:
        vals = np.asarray(vals)

    if w_list is None:
        # Random unit-norm hyperplanes through the origin.
        w_list = np.random.standard_normal((num_internal_nodes, dim))
        w_list = w_list/np.linalg.norm(w_list, axis=1)[:, None]
        b_list = np.zeros((num_internal_nodes))
    else:
        w_list = np.asarray(w_list)
        # Previously a missing b_list crashed on `+ None`.
        b_list = np.zeros((num_internal_nodes)) if b_list is None else np.asarray(b_list)

    # Gaussian draws rescaled to unit Euclidean norm.
    data_x = np.random.standard_normal((num_data, dim))
    data_x /= np.sqrt(np.sum(data_x**2, axis=1, keepdims=True))

    # relevant_stats[i, k] = w_k @ x_i + b_k for every internal node k.
    # The bias is added exactly once here; the original code added it again
    # inside the level loop, double-counting any non-zero offset.
    relevant_stats = data_x @ w_list.T + b_list

    # Walk every point from the root down to a leaf.
    curr_index = np.zeros(shape=(num_data), dtype=int)
    row_ids = np.arange(num_data)
    for _ in range(num_levels):
        # Integer indexing instead of np.choose, which limits the number of
        # choice arrays and therefore the tree depth.
        decision_variable = relevant_stats[row_ids, curr_index]
        # Left child is 2k+1, right child is 2k+2.
        curr_index = 2*curr_index + 1 + (decision_variable > 0)

    # Smallest absolute margin of each point over all internal hyperplanes.
    bound_dist = np.min(np.abs(relevant_stats), axis=1)
    labels = vals[curr_index]

    keep = bound_dist > threshold
    data_x_pruned = data_x[keep]
    labels_pruned = labels[keep]

    # Side of each internal hyperplane (+1 / -1) for the retained points.
    side = np.sign(data_x_pruned @ w_list.T + b_list)

    # nodes_active[i, k] == 1 iff retained point i passes through node k.
    nodes_active = np.zeros((len(data_x_pruned), num_nodes), dtype=np.int32)
    nodes_active[:, 0] = 1  # every point starts at the root
    for node in range(1, num_nodes):
        parent = (node-1)//2
        is_right_child = (node - 2*parent - 1) == 1
        if is_right_child:
            took_branch = side[:, parent] > 0
        else:
            took_branch = side[:, parent] < 0
        nodes_active[:, node] = nodes_active[:, parent] * took_branch

    # Number of retained points reaching each node (computed once, after the loop).
    stats = nodes_active.sum(axis=0)
    return ((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)

In [None]:
# print(train_data.shape)
# print(train_data_labels.shape)

# print(test_data.shape)
# print(test_data_labels.shape)


In [None]:
class DLGN_FC(nn.Module):
    """DLGN with soft gates and one learnable value per path.

    Each hidden layer ``i`` owns a bias-free linear map from the *input* to
    ``num_hidden_nodes[i]`` gate pre-activations.  A path picks one node in
    every hidden layer; its strength for an input is the product of the
    corresponding soft gates ``sigmoid(beta * w @ x)`` and the network output
    is the sum over all paths of ``strength * value``.

    Args:
        input_dim: dimension of the input vectors.
        output_dim: stored in ``num_nodes`` but not used by ``forward``.
        num_hidden_nodes: list with the number of nodes in each hidden layer
            (None is treated as an empty list).
        beta: soft gating parameter multiplying the gate pre-activation.
        mode: stored on the instance; not used inside this class.
    """

    def __init__(self, input_dim=None, output_dim=None, num_hidden_nodes=None, beta=30, mode='pwc'):
        super(DLGN_FC, self).__init__()
        # Copy so that neither a shared default nor the caller's list is aliased.
        num_hidden_nodes = [] if num_hidden_nodes is None else list(num_hidden_nodes)
        self.num_hidden_layers = len(num_hidden_nodes)
        self.beta=beta  # Soft gating parameter
        self.mode = mode
        self.num_nodes=[input_dim]+num_hidden_nodes+[output_dim]
        self.gating_layers=nn.ModuleList()
        # One value per path: shape [1, n_1, ..., n_L].  Created before the
        # gating layers so the random-number stream is consumed in the same
        # order as before.
        self.value_layers=nn.Parameter(torch.randn([1]+num_hidden_nodes)/100.)
        self.num_layer = len(num_hidden_nodes)
        self.num_hidden_nodes = num_hidden_nodes
        for i in range(self.num_hidden_layers):
            # Every gating layer reads the raw input (num_nodes[0]), not the
            # previous layer's output.
            self.gating_layers.append(
                nn.Linear(self.num_nodes[0], self.num_nodes[i+1], bias=False))

    def set_parameters_with_mask(self, to_copy, parameter_masks):
        """Copy parameters from ``to_copy`` into this model, optionally masked.

        Args:
            to_copy: a ``DLGN_FC`` with the same architecture as ``self``.
            parameter_masks: dict keyed by parameter name.  Where a mask is
                present only entries with ``mask > 0`` are copied; a parameter
                without a mask is copied entirely.
        """
        own_state = self.state_dict()  # tensors share storage with the parameters
        with torch.no_grad():
            for (name, copy_param) in to_copy.named_parameters():
                orig_param = own_state[name]
                copy_param = copy_param.detach().to(orig_param.device)
                if name in parameter_masks:
                    param_mask = (parameter_masks[name]>0).to(orig_param.device)
                    orig_param[param_mask] = copy_param[param_mask]
                else:
                    # In-place copy; the previous code only rebound a local
                    # variable and therefore copied nothing.
                    orig_param.copy_(copy_param)

    def return_gating_functions(self):
        """Return detached copies of the gating weight matrices.

        Returns:
            A list of length ``num_hidden_layers``; entry ``i`` is a clone of
            ``gating_layers[i].weight``.
        """
        effective_weights = []
        for i in range(self.num_hidden_layers):
            curr_weight = self.gating_layers[i].weight.detach().clone()
            effective_weights.append(curr_weight)
        return effective_weights

    def forward(self, x):
        """Compute the scalar network output for each row of ``x``.

        Args:
            x: tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        path_strength = None
        for i in range(self.num_hidden_layers):
            # Broadcast shape with the layer's nodes on axis i+1, e.g.
            # (batch, 1, n_2, 1, 1) for the second of four layers.
            fiber = [len(x)]+[1]*self.num_layer
            fiber[i+1] = self.num_hidden_nodes[i]
            gate_score = torch.sigmoid(self.beta*(x@self.gating_layers[i].weight.T))
            gate_score = gate_score.reshape(tuple(fiber))
            # Broadcasting the product builds the (batch, n_1, ..., n_L)
            # tensor of path strengths.
            path_strength = gate_score if path_strength is None else path_strength*gate_score

        if path_strength is None:
            # No hidden layers: the single value is the output for every row.
            return self.value_layers.expand(len(x))

        # Sum over every path axis (previously hard-coded to exactly four).
        path_axes = tuple(range(1, self.num_layer+1))
        return torch.sum(path_strength*self.value_layers, dim=path_axes)

In [None]:
#@title Train DLGN model
def train_dlgn (DLGN_obj, train_data_curr,vali_data_curr,test_data_curr,
                train_labels_curr,test_labels_curr,vali_labels_curr,
                parameter_mask=None):
    """Train the gates of a DLGN with Adam, periodically refitting its value tensor.

    Gradient steps update only the gating layers (the gradient of
    ``value_layers`` is zeroed before every optimiser step).  At each epoch
    listed in ``update_value_epochs`` the value tensor is replaced by the
    coefficients of an L1-regularised logistic regression fitted on the
    current path features of the training set.

    The following module-level names are read at call time and must be
    defined before calling: ``lr`` (Adam learning rate), ``no_of_batches``,
    ``saved_epochs`` (its last entry fixes the number of epochs) and
    ``update_value_epochs``.  The gate sharpness and layer sizes are taken
    from ``DLGN_obj.beta`` and ``DLGN_obj.num_hidden_nodes``.

    Args:
        DLGN_obj: the initial network; it is moved to the GPU when one is
            available and trained in place.
        train_data_curr, vali_data_curr, test_data_curr: feature arrays.
        train_labels_curr, vali_labels_curr, test_labels_curr: 0/1 label arrays.
        parameter_mask: accepted for interface compatibility; currently unused.

    Returns:
        ``(train_losses, DLGN_obj_return, DLGN_obj_store, losses, debug_models)``:
        the training losses recorded with each stored snapshot, the model
        with the lowest validation error (on CPU), the list of CPU snapshots,
        the per-epoch training losses, and an (always empty) debug list.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    DLGN_obj.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(DLGN_obj.parameters(), lr=lr)

    # Move the datasets to the device once instead of once per epoch/batch.
    train_data_torch = torch.Tensor(train_data_curr).to(device)
    vali_data_torch = torch.Tensor(vali_data_curr).to(device)
    test_data_torch = torch.Tensor(test_data_curr).to(device)
    train_labels_torch = torch.tensor(train_labels_curr, dtype=torch.int64).to(device)
    vali_labels_torch = torch.tensor(vali_labels_curr, dtype=torch.int64).to(device)

    def _two_class_logits(scores):
        # The network emits one score s per sample; (-s, s) turns it into
        # logits for classes 0 and 1.
        scores = scores.reshape(-1, 1)
        return torch.cat((-1*scores, scores), dim=1)

    def _full_train_loss():
        # Evaluation only: no autograd graph over the whole training set.
        with torch.no_grad():
            return criterion(_two_class_logits(DLGN_obj(train_data_torch)), train_labels_torch)

    def _refit_value_tensor():
        # Path features: outer product over layers of the soft gate values,
        # built with the model's own beta and layer sizes so any depth works.
        hidden_sizes = list(DLGN_obj.num_hidden_nodes)
        n_layers = len(hidden_sizes)
        cp_feat = None
        for layer_idx, gate_w in enumerate(DLGN_obj.return_gating_functions()):
            shape = [-1] + [1]*n_layers
            shape[layer_idx+1] = hidden_sizes[layer_idx]
            pre_act = np.dot(train_data_curr, gate_w.cpu().numpy().T)
            layer_feat = sigmoid(DLGN_obj.beta*pre_act.reshape(shape))
            cp_feat = layer_feat if cp_feat is None else cp_feat*layer_feat
        cp_feat_vec = cp_feat.reshape((len(cp_feat), -1))

        clf = LogisticRegression(C=0.03, fit_intercept=False,max_iter=1000, penalty="l1", solver='liblinear')
        clf.fit(2*cp_feat_vec, train_labels_curr)
        # With fit_intercept=False, decision_function(identity) equals coef_;
        # reading coef_ avoids allocating an (n_paths x n_paths) identity.
        value_wts = clf.coef_.reshape(tuple(DLGN_obj.value_layers.shape))
        with torch.no_grad():
            DLGN_obj.value_layers.copy_(torch.Tensor(value_wts))

    num_batches = no_of_batches
    # Guard against a zero step when there are fewer samples than batches.
    batch_size = max(1, len(train_data_curr)//num_batches)
    losses=[]
    DLGN_obj_store = []
    DLGN_obj_return = None
    best_vali_error = len(vali_labels_curr)
    debug_models= []
    train_losses = []

    snapshot_epochs = set(saved_epochs)
    refit_epochs = set(update_value_epochs)

    tepoch = tqdm(range(saved_epochs[-1]+1))
    for epoch in tepoch:  # loop over the dataset multiple times
        if epoch in refit_epochs:
            train_loss = _full_train_loss()
            print("Loss before updating value_net at epoch", epoch, " is ", train_loss)
            print("Total path abs value", torch.abs(DLGN_obj.value_layers.cpu().detach()).sum().numpy())

            _refit_value_tensor()

            train_loss = _full_train_loss()
            print("Loss after updating value_net at epoch", epoch, " is ", train_loss)			
            print("Total path abs value", torch.abs(DLGN_obj.value_layers.cpu().detach()).sum().numpy())
            # NOTE(review): snapshots are only taken at epochs that are in
            # BOTH saved_epochs and update_value_epochs - confirm intended.
            if epoch in snapshot_epochs:
                DLGN_obj_copy = deepcopy(DLGN_obj)
                DLGN_obj_copy.to(torch.device('cpu'))
                DLGN_obj_store.append(DLGN_obj_copy)
                train_losses.append(train_loss.detach())

        for batch_start in range(0,len(train_data_curr),batch_size):
            if (batch_start+batch_size)>len(train_data_curr):
                break  # drop the incomplete final batch
            optimizer.zero_grad()
            inputs = train_data_torch[batch_start:batch_start+batch_size]
            targets = train_labels_torch[batch_start:batch_start+batch_size].reshape(batch_size)
            loss = criterion(_two_class_logits(DLGN_obj(inputs)), targets)
            loss.backward()
            # Freeze the value tensor during gradient steps; it is only
            # changed by the logistic-regression refit above.
            for name,param in DLGN_obj.named_parameters():
                if "val" in name and param.grad is not None:
                    param.grad.zero_()
            optimizer.step()

        train_loss = _full_train_loss()
        if epoch%5 == 0:
            print("Loss after updating at epoch ", epoch, " is ", train_loss)
            with torch.no_grad():
                test_preds = DLGN_obj(test_data_torch).reshape(-1,1)
            test_preds = test_preds.cpu().numpy()
            print("Test error=",np.sum(test_labels_curr != (np.sign(test_preds[:,0])+1)//2 ))
        if train_loss < 0.005:
            break
        if torch.isnan(train_loss):
            break

        losses.append(train_loss.cpu().detach().clone().numpy())
        with torch.no_grad():
            vali_preds = torch.argmax(_two_class_logits(DLGN_obj(vali_data_torch)), dim=1)
        vali_error= torch.sum(vali_labels_torch!=vali_preds)
        if vali_error < best_vali_error:
            DLGN_obj_return = deepcopy(DLGN_obj)
            best_vali_error = vali_error

    if DLGN_obj_return is None:
        # The loop stopped before any validation improvement was recorded
        # (e.g. early stop in the first epoch); return the current model
        # instead of failing on an unbound name.
        DLGN_obj_return = deepcopy(DLGN_obj)

    plt.figure()
    plt.title("DLGN loss vs epoch")
    plt.plot(losses)
    DLGN_obj_return.to(torch.device('cpu'))
    return train_losses, DLGN_obj_return, DLGN_obj_store, losses, debug_models

**Training a DLGN model**

In [None]:
# ---- Experiment configuration (several names below are read as globals by
# ---- train_dlgn and data_gen_decision_tree, so they must stay module-level).
args =  Args()

num_layer = args.numlayer
num_neuron = args.numnodes
beta = args.beta
lr=args.lr

# Epochs at which model snapshots may be stored; the last entry also fixes
# the total number of training epochs.
saved_epochs = list(range(0,301,10)) + list(range(301,15301,500))
# Epochs at which the value tensor is refit by logistic regression.
update_value_epochs = list(range(0,15301,100))


no_of_batches=10 #[1,10,100]
weight_decay=0.0
num_hidden_nodes=[num_neuron]*num_layer

seed=365
num_levels=4
threshold = 0 #data seperation distance

optimizer_name ='Adam'
modep='pwc' 
output_dim=1

# One synthetic dataset per entry: input dimension and number of points drawn.
data_configs = [
    {"input_dim": 20, "num_data": 40000},
    {"input_dim": 100, "num_data": 60000},
    {"input_dim": 500, "num_data": 100000}
]

# The final evaluation always runs on the CPU.  A separate name is used so
# that the notebook-level `device` is not overwritten between configurations.
cpu_device = torch.device('cpu')

for config in data_configs:
    input_dim = config["input_dim"]
    num_data = config["num_data"]
    
    print("==========input_dim:",input_dim,"==============num_data:",num_data)

    ((data_x, labels), (w_list, b_list, vals), stats) = data_gen_decision_tree(
                                                dim=input_dim, seed=seed, num_levels=num_levels,
                                                num_data=num_data)
    seed_set=seed
    w_list_old = np.array(w_list)
    b_list_old = np.array(b_list)
    # Class balance of the generated labels.
    print(np.sum(labels==1))
    print(np.sum(labels==0))
    print("Seed= ",seed_set)

    # 50% train / 25% validation / remainder test, in generation order.
    num_data = len(data_x)
    num_train= num_data//2
    num_vali = num_data//4
    num_test = num_data//4
    train_data = data_x[:num_train,:]
    train_data_labels = labels[:num_train]

    vali_data = data_x[num_train:num_train+num_vali,:]
    vali_data_labels = labels[num_train:num_train+num_vali]

    test_data = data_x[num_train+num_vali :,:]
    test_data_labels = labels[num_train+num_vali :]    

    print("---" * 30)
    set_torchseed(41972)
    DLGN_init= DLGN_FC(input_dim=input_dim, output_dim=1, num_hidden_nodes=num_hidden_nodes, beta=beta)

    # All-ones masks: every value and gating parameter is trainable.
    train_parameter_masks=dict()
    for name,parameter in DLGN_init.named_parameters():
        if "val" in name or "gat" in name:
            # Tensor.to() returns a new tensor, so its result has to be
            # stored (the bare call used previously had no effect).
            train_parameter_masks[name]=torch.ones_like(parameter).to(device)

    set_torchseed(5000)
    train_losses, DLGN_obj_final, DLGN_obj_store, losses , debug_models= train_dlgn(train_data_curr=train_data,
                                                vali_data_curr=vali_data,
                                                test_data_curr=test_data,
                                                train_labels_curr=train_data_labels,
                                                vali_labels_curr=vali_data_labels,
                                                test_labels_curr=test_data_labels,
                                                DLGN_obj=deepcopy(DLGN_init),
                                                parameter_mask=train_parameter_masks,
                                                )

    torch.cuda.empty_cache() 
    losses=np.array(losses)

    # ---- Final evaluation of the returned model on the CPU.
    # no_grad: inference only, so no autograd graph is kept for the full
    # training / test sets.
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        train_preds =DLGN_obj_final(torch.Tensor(train_data).to(cpu_device)).reshape(-1,1)
        outputs = torch.cat((-1*train_preds,train_preds), dim=1)
        targets = torch.tensor(train_data_labels, dtype=torch.int64)
        train_loss = criterion(outputs, targets)
    train_preds = train_preds.numpy()
    # sign(score) in {-1, +1} is mapped to the predicted label in {0, 1}.
    Train_error = np.sum(train_data_labels != (np.sign(train_preds[:,0])+1)//2)
    Num_train_data = len(train_data_labels)
    print("Train error=",Train_error)
    print("Num_train_data=",Num_train_data)
    print("Train_acc:",1-Train_error/Num_train_data)
    
    with torch.no_grad():
        test_preds =DLGN_obj_final(torch.Tensor(test_data)).reshape(-1,1)
    test_preds = test_preds.numpy()

    Test_error = np.sum(test_data_labels != (np.sign(test_preds[:,0])+1)//2)
    Num_test_data = len(test_data_labels)
    print("Test error=",Test_error)
    print("Num_test_data=",Num_test_data)
    print("Test_acc:",1-Test_error/Num_test_data)

# print(DLGN_obj_store[-1].beta)
    