**Import**

In [14]:
#Import
%reset -f
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import copy
import numpy as np
np.set_printoptions(precision=4)
import os
import random
from copy import deepcopy
import torchvision
import torchvision.transforms as transforms

**Device**

In [15]:
#Device
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


**Synthetic Dataset**

In [16]:
#Synthetic Dataset
def set_npseed(seed):
    """Seed NumPy's legacy global RNG so later ``np.random.*`` draws are reproducible.

    Args:
        seed: integer seed forwarded to ``np.random.seed``.
    """
    np.random.seed(seed=seed)


def set_torchseed(seed):
    """Seed torch's CPU and CUDA generators and force deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to the CPU generator, the current CUDA device
            and every CUDA device.
    """
    torch.manual_seed(seed)
    # Both CUDA seeding calls are silently ignored when CUDA is unavailable.
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # deterministic convolution algorithms
    cudnn.benchmark = False      # no autotuning (its choice can vary run to run)


#classification data

def data_gen_decision_tree(num_data=1000, dim=2, seed=0, w_list=None, b_list=None, vals=None, num_levels=2,
                           threshold=None):
    """Generate labelled points on the unit sphere from a complete oblique decision tree.

    The tree has ``2**num_levels - 1`` internal nodes (each a hyperplane ``w.x + b``) and
    ``2**num_levels`` leaves, numbered in heap order: node ``i`` has left child ``2i+1``
    (taken when ``w_i.x + b_i <= 0``) and right child ``2i+2`` (taken when ``> 0``).
    A point's label is the ``vals`` entry of the leaf it reaches.

    Args:
        num_data: number of points sampled before pruning.
        dim: input dimension.
        seed: seed passed to ``set_npseed`` (reseeds NumPy's global RNG).
        w_list: (num_internal_nodes, dim) hyperplane normals; random unit vectors if None.
        b_list: (num_internal_nodes,) hyperplane offsets; zeros if None.
        vals: label per node (length num_internal_nodes + num_leaf_nodes); only leaf
            entries are used. Defaults to index parity, with -99 on internal nodes.
        num_levels: depth of the tree.
        threshold: a point is kept only if ``min_k |w_k.x + b_k| > threshold`` over all
            internal nodes. ``None`` keeps the legacy behaviour of reading the notebook's
            global ``threshold`` (0 if undefined).

    Returns:
        ``((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)``. ``stats[k]`` is
        the number of retained points that reach node ``k``.
    """
    set_npseed(seed=seed)

    num_internal_nodes = 2**num_levels - 1
    num_leaf_nodes = 2**num_levels
    num_nodes = num_internal_nodes + num_leaf_nodes

    if threshold is None:
        # Backward compatibility: this function used to read the config cell's global.
        threshold = globals().get('threshold', 0)

    if vals is None:
        # Leaf label = parity of the node index; internal nodes get a sentinel because
        # only leaf labels are ever looked up.
        vals = np.arange(num_nodes, dtype=np.int32) % 2
        vals[:num_internal_nodes] = -99

    if w_list is None:
        w_list = np.random.standard_normal((num_internal_nodes, dim))
        w_list = w_list / np.linalg.norm(w_list, axis=1)[:, None]  # unit-norm normals
    w_list = np.asarray(w_list)
    # Hyperplanes pass through the origin unless offsets are supplied. A user b_list is
    # respected even when w_list is generated, and w_list without b_list no longer crashes.
    b_list = np.zeros(num_internal_nodes) if b_list is None else np.asarray(b_list)

    # Normalised Gaussian vectors: points on the unit sphere.
    data_x = np.random.standard_normal((num_data, dim))
    data_x /= np.sqrt(np.sum(data_x**2, axis=1, keepdims=True))

    # Signed score w_k.x + b_k of every point at every internal node. The bias is added
    # exactly once so routing/pruning agree with the node statistics computed below.
    relevant_stats = data_x @ w_list.T + b_list

    # Route each point from the root to a leaf: i -> 2i+1 if score <= 0, else 2i+2.
    rows = np.arange(num_data)
    curr_index = np.zeros(num_data, dtype=int)
    for _ in range(num_levels):
        # Fancy indexing rather than np.choose, whose number of choices is capped
        # (NPY_MAXARGS), which broke trees with more than ~5 levels.
        decision_variable = relevant_stats[rows, curr_index]
        curr_index = 2 * curr_index + 1 + (decision_variable > 0)
    labels = vals[curr_index]

    # Discard points within `threshold` of any internal hyperplane (separation margin).
    # `initial=inf` keeps this well-defined for a tree with no internal nodes.
    bound_dist = np.min(np.abs(relevant_stats), axis=1, initial=np.inf)
    keep = bound_dist > threshold
    data_x_pruned = data_x[keep]
    labels_pruned = labels[keep]

    # Node reachability for the retained points: +1 means right branch, -1 left, 0 on the plane.
    side = np.sign(data_x_pruned @ w_list.T + b_list)
    nodes_active = np.zeros((len(data_x_pruned), num_nodes), dtype=np.int32)
    nodes_active[:, 0] = 1  # every point reaches the root
    for node in range(1, num_nodes):
        parent = (node - 1) // 2
        is_right_child = node % 2 == 0  # heap order: 2p+1 is odd (left), 2p+2 even (right)
        took_branch = side[:, parent] > 0 if is_right_child else side[:, parent] < 0
        nodes_active[:, node] = nodes_active[:, parent] * took_branch
    stats = nodes_active.sum(axis=0)  # retained points reaching each node

    return ((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)

**DLGN and DLGN-SF models**

In [17]:
#DLGN and DLGN_SF Model
class DLGN_FC(nn.Module):
    """Fully connected Deep Linearly Gated Network (DLGN) and its DLGN-SF variant.

    Two parallel stacks of bias-free linear layers are built:

    * a gating network whose scores are turned into soft gates ``sigmoid(beta * score)``;
    * a value network whose hidden activations are multiplied element-wise by those gates.

    ``dlgn_mode='dlgn_sf'``: every gating layer reads the raw input ``x``. Any other value
    (e.g. ``'dlgn'``): gating layers are chained, so layer ``i`` reads the pre-sigmoid
    scores of layer ``i-1``.

    ``mode='pwc'`` feeds a tensor of ones into the value network, so the output depends on
    ``x`` only through the gates. Any other mode feeds ``x`` itself.
    """

    def __init__(self, input_dim=None, output_dim=None, num_hidden_nodes=None, beta=30, dlgn_mode='dlgn_sf', mode='pwc'):
        super(DLGN_FC, self).__init__()
        # None instead of a shared mutable [] default; any sequence of widths is accepted.
        num_hidden_nodes = [] if num_hidden_nodes is None else list(num_hidden_nodes)
        self.num_hidden_layers = len(num_hidden_nodes)
        self.beta = beta  # soft-gating sharpness
        self.dlgn_mode = dlgn_mode
        self.mode = mode
        self.num_nodes = [input_dim] + num_hidden_nodes + [output_dim]
        self.gating_layers = nn.ModuleList()
        self.value_layers = nn.ModuleList()

        # Gating layer i is created before value layer i (same order as before), so
        # seeded initialisations are unchanged.
        for i in range(self.num_hidden_layers + 1):
            if i != self.num_hidden_layers:
                gate_in = self.num_nodes[0] if self.dlgn_mode == 'dlgn_sf' else self.num_nodes[i]
                self.gating_layers.append(nn.Linear(gate_in, self.num_nodes[i + 1], bias=False))
            self.value_layers.append(nn.Linear(self.num_nodes[i], self.num_nodes[i + 1], bias=False))

    def set_parameters_with_mask(self, to_copy, parameter_masks):
        """Copy parameters from ``to_copy`` into ``self`` in place.

        Args:
            to_copy: a DLGN_FC with the same architecture as ``self``.
            parameter_masks: dict keyed like ``dict(to_copy.named_parameters())``. A masked
                parameter only receives the entries where ``mask > 0``. A parameter with no
                mask is copied in full.
        """
        own_params = dict(self.named_parameters())
        with torch.no_grad():
            for name, copy_param in to_copy.named_parameters():
                target = own_params[name]
                source = copy_param.detach().to(device=target.device, dtype=target.dtype)
                if name in parameter_masks:
                    keep = (parameter_masks[name] > 0).to(target.device)
                    target[keep] = source[keep]
                else:
                    # Must copy in place: rebinding a local name leaves the parameter untouched.
                    target.copy_(source)

    def return_gating_functions(self):
        """Return the effective linear map from the input to each gating layer's scores.

        Returns:
            List of ``num_hidden_layers`` detached tensors. Entry ``i`` has shape
            ``(num_nodes[i+1], input_dim)``. In DLGN-SF mode it is gating layer ``i``'s weight.
            In chained mode it is the product ``W_i @ ... @ W_0``.
        """
        effective_weights = []
        for i in range(self.num_hidden_layers):
            curr_weight = self.gating_layers[i].weight.detach().clone()
            if self.dlgn_mode == 'dlgn_sf' or i == 0:
                effective_weights.append(curr_weight)
            else:
                effective_weights.append(torch.matmul(curr_weight, effective_weights[-1]))
        return effective_weights

    def forward(self, x):
        """Run the gating and value networks.

        Args:
            x: input tensor whose last dimension is ``input_dim`` ((batch, input_dim) in this
                notebook). It is moved to the model's device if necessary.

        Returns:
            ``(values, gate_scores)``.
            ``values`` has ``num_hidden_layers + 2`` entries: the value-network input, each
            gated hidden layer, and the output.
            ``gate_scores`` has ``num_hidden_layers + 1`` entries: the input, followed by each
            gating layer's pre-sigmoid scores.
        """
        # Use the device the parameters actually live on (not a hard-coded 'cuda:0'), and
        # move the input once instead of moving layers/tensors inside the loop.
        device = next(self.parameters()).device
        x = x.to(device)
        gate_scores = [x]
        if self.mode == 'pwc':
            values = [torch.ones(x.shape, device=device)]
        else:
            values = [x]

        for i in range(self.num_hidden_layers):
            if self.dlgn_mode == 'dlgn_sf':
                gate_scores.append(x @ self.gating_layers[i].weight.T)
            else:
                gate_scores.append(self.gating_layers[i](gate_scores[-1]))
            curr_gate_on_off = torch.sigmoid(self.beta * gate_scores[-1])
            values.append(self.value_layers[i](values[-1]) * curr_gate_on_off)
        values.append(self.value_layers[self.num_hidden_layers](values[-1]))
        return values, gate_scores

**Train DLGN Model**

In [18]:
#@title Train DLGN model
def _notebook_setting(value, name):
    """Return ``value``, or the notebook global ``name`` when ``value`` is None (legacy behaviour)."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"train_dlgn needs `{name}`: pass it as a keyword argument "
                        f"or define it in the config cell") from None


def _two_class_logits(scores):
    """Stack ``[-s, s]`` along dim 1 so a single score per example becomes two-class logits."""
    return torch.cat((-1 * scores, scores), dim=1)


def _full_data_loss(model, inputs, targets, criterion):
    """Cross-entropy of ``model`` over a whole dataset, computed without building a graph."""
    with torch.no_grad():
        values, _ = model(inputs)
        return criterion(_two_class_logits(values[-1]), targets)


def train_dlgn(DLGN_obj, train_data_curr, vali_data_curr, test_data_curr,
               train_labels_curr, test_labels_curr, vali_labels_curr, num_epoch=1,
               parameter_mask=None, lr=None, seed=None, no_of_batches=None,
               saved_epochs=None, x_epoch=None):
    """Train a DLGN_FC with Adam and cross-entropy on two-class labels.

    Args:
        DLGN_obj: the initial network; trained in place and moved to the training device.
        train_data_curr, vali_data_curr, test_data_curr: input arrays. Test data is
            accepted for API compatibility but not used here.
        train_labels_curr, test_labels_curr, vali_labels_curr: 0/1 integer labels.
        num_epoch: unused (kept for compatibility); the number of epochs is
            ``saved_epochs[-1] + 1``.
        parameter_mask: dict keyed like ``dict(DLGN_obj.named_parameters())``. Each
            parameter's gradient is multiplied by its mask. Parameters without an entry
            are fully trainable. The dict is not modified.
        lr, seed, no_of_batches, saved_epochs, x_epoch: hyper-parameters. When omitted
            (None), the notebook globals of the same name are used, as before.
            ``saved_epochs`` lists the epochs at which a CPU snapshot and full-train loss
            are recorded. After epoch ``x_epoch``, gating-layer gradients are zeroed.

    Returns:
        ``(train_losses, DLGN_obj_return, DLGN_obj_store, losses, debug_models)``:
        full-train loss at each saved epoch (tensors, no graph); the CPU copy with the
        lowest validation error; CPU snapshots at saved epochs; full-train loss after each
        epoch (0-d numpy arrays); and an always-empty list.
    """
    lr = _notebook_setting(lr, 'lr')
    seed = _notebook_setting(seed, 'seed')
    no_of_batches = _notebook_setting(no_of_batches, 'no_of_batches')
    saved_epochs = _notebook_setting(saved_epochs, 'saved_epochs')
    x_epoch = _notebook_setting(x_epoch, 'x_epoch')
    parameter_mask = {} if parameter_mask is None else parameter_mask

    print("train_data_curr inside train_dlgn:",train_data_curr.shape)
    set_torchseed(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    DLGN_obj.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(DLGN_obj.parameters(), lr=lr)

    train_inputs = torch.Tensor(train_data_curr).to(device)
    train_targets = torch.tensor(train_labels_curr, dtype=torch.int64).to(device)
    vali_inputs = torch.Tensor(vali_data_curr).to(device)
    vali_targets = torch.tensor(vali_labels_curr, dtype=torch.int64).to(device)
    # Local device copies: the caller's mask dict is left untouched.
    grad_masks = {name: mask.to(device) for name, mask in parameter_mask.items()}

    num_train = len(train_data_curr)
    batch_size = num_train // no_of_batches
    if batch_size == 0:
        raise ValueError(f"no_of_batches={no_of_batches} exceeds the {num_train} training points")

    losses = []
    DLGN_obj_store = []
    debug_models = []
    train_losses = []
    saved_epoch_set = set(saved_epochs)
    best_vali_error = len(vali_labels_curr)
    # Fallback so a model is always returned, even if validation never improves or
    # training stops at the very first checkpoint.
    DLGN_obj_return = deepcopy(DLGN_obj)

    for epoch in tqdm(range(saved_epochs[-1] + 1)):
        if epoch in saved_epoch_set:
            DLGN_obj_copy = deepcopy(DLGN_obj)
            DLGN_obj_copy.to(torch.device('cpu'))
            DLGN_obj_store.append(DLGN_obj_copy)
            # Computed without autograd so stored losses do not pin a full-data graph.
            train_loss = _full_data_loss(DLGN_obj, train_inputs, train_targets, criterion)
            train_losses.append(train_loss)
            if epoch % 100 == 0:
                print(train_loss)
            if train_loss < 5e-6 or torch.isnan(train_loss):
                break

        freeze_gates = epoch > x_epoch
        # Sequential mini-batches; a trailing partial batch is dropped.
        for batch_start in range(0, num_train - batch_size + 1, batch_size):
            optimizer.zero_grad()
            inputs = train_inputs[batch_start:batch_start + batch_size]
            targets = train_targets[batch_start:batch_start + batch_size]
            values, _ = DLGN_obj(inputs)
            loss = criterion(_two_class_logits(values[-1]), targets)
            loss.backward()
            # NOTE(review): with Adam, zeroing a gradient does not fully freeze a weight that
            # already has momentum (e.g. gates after x_epoch) — it keeps drifting briefly.
            for name, param in DLGN_obj.named_parameters():
                if param.grad is None:
                    continue
                mask = grad_masks.get(name)
                if mask is not None:
                    param.grad *= mask
                if freeze_gates and "gat" in name:
                    param.grad *= 0.
            optimizer.step()

        train_loss = _full_data_loss(DLGN_obj, train_inputs, train_targets, criterion)
        losses.append(train_loss.detach().cpu().clone().numpy())
        with torch.no_grad():
            vali_values, _ = DLGN_obj(vali_inputs)
            vali_preds = torch.argmax(_two_class_logits(vali_values[-1]), dim=1)
            vali_error = torch.sum(vali_targets != vali_preds).item()
        if vali_error < best_vali_error:
            DLGN_obj_return = deepcopy(DLGN_obj)
            best_vali_error = vali_error

    DLGN_obj_return.to(torch.device('cpu'))
    return train_losses, DLGN_obj_return, DLGN_obj_store, losses, debug_models

**Change the values in this cell to run different configurations**

In [20]:
#In this cell change the parameter values to train on different synthetic data with different models.
#Dataset Characteristics

'''
For SDI input_dim=20, num_data=40000
DLGN Best Parameters:
Beta     LR  Hidden Layers       Hidden Nodes            Test Accuracy
3      0.020        5           [20, 20, 20, 20, 20]         0.9605

DLGN-SF Best Parameters:
Beta     LR  Hidden Layers        Hidden Nodes           Test Accuracy
30     0.020        4           [10, 10, 10, 10]              0.9743
'''

'''
For SDII input_dim=100, num_data=60000
DLGN Best Parameters:
Beta     LR  Hidden Layers       Hidden Nodes         Test Accuracy
10     0.010      4             [20, 20, 20, 20]        0.94247

DLGN-SF Best Parameters:
Beta     LR  Hidden Layers        Hidden Nodes          Test Accuracy
3      0.020       4             [10, 10, 10, 10]        0.90293
'''

'''
For SDIII input_dim=500, num_data=100000
DLGN Best Parameters:
Beta     LR  Hidden Layers       Hidden Nodes         Test Accuracy
10      0.001      3             [10, 10, 10]          0.65036

DLGN-SF Best Parameters:
Beta     LR  Hidden Layers        Hidden Nodes        Test Accuracy
 3     0.010        3             [20, 20, 20]          0.63832
'''

seed=365
num_levels=4
threshold = 0  # data separation margin: points this close to any tree hyperplane are dropped

optimizer_name ='Adam'  # informational only: train_dlgn always uses Adam
modep='pwc'             # value-network input mode passed to DLGN_FC
output_dim=1
num_epoch=1500  # informational only: train_dlgn runs saved_epochs[-1]+1 epochs

x_epoch = 1500  # gating-layer gradients are zeroed after this epoch
saved_epochs = list(range(0,1501,10))  # epochs at which a snapshot + train loss is recorded
weight_decay=0.0  # NOTE(review): not used by train_dlgn in this notebook
no_of_batches=10 #[1,10,100]

input_dim = 20 #Synthetic data input dimension
num_data = 40000 #Total data points (before pruning)


print(f"Running code for input_dim={input_dim}, num_data={num_data}")

((data_x, labels), (w_list, b_list, vals), stats) = data_gen_decision_tree(
                                            dim=input_dim, seed=seed, num_levels=num_levels,
                                            num_data=num_data)
seed_set=seed
w_list_old = np.array(w_list)
b_list_old = np.array(b_list)
print(int(np.sum(labels==1)))
print(int(np.sum(labels==0)))
print("Seed= ",seed_set)

# 50/25/25 train/validation/test split of the (pruned) data.
num_data = len(data_x)
num_train= num_data//2
num_vali = num_data//4
num_test = num_data//4
train_data = data_x[:num_train,:]
train_data_labels = labels[:num_train]

vali_data = data_x[num_train:num_train+num_vali,:]
vali_data_labels = labels[num_train:num_train+num_vali]

test_data = data_x[num_train+num_vali :,:]
test_data_labels = labels[num_train+num_vali :]

dlgn_mode = 'dlgn_sf' #use dlgn_sf for running dlgn-sf model, 'dlgn' for the chained DLGN
beta = 10
lr = 0.002

num_hidden_layers = 4
num_hidden_nodes = [20,20,20,20]

print(f"Running code for num_hidden_layers={num_hidden_layers}, num_hidden_nodes={num_hidden_nodes}")

max_no_of_nodes=max(num_hidden_nodes)

set_torchseed(6675)
DLGN_init= DLGN_FC(input_dim=input_dim, output_dim=output_dim, num_hidden_nodes=num_hidden_nodes,
                   beta=beta, dlgn_mode=dlgn_mode, mode=modep)

# All-ones gradient masks: every value- and gating-layer weight is trainable. Zero (or
# scale) entries here to restrict which weights train_dlgn updates. Masks are created
# on the training device directly (the former `.to(device)` result was discarded).
train_parameter_masks = {
    name: torch.ones_like(parameter, device=device)
    for name, parameter in DLGN_init.named_parameters()
}

set_torchseed(5000)
train_losses, DLGN_obj_final, DLGN_obj_store, losses , debug_models= train_dlgn(train_data_curr=train_data,
                                            vali_data_curr=vali_data,
                                            test_data_curr=test_data,
                                            train_labels_curr=train_data_labels,
                                            vali_labels_curr=vali_data_labels,
                                            test_labels_curr=test_data_labels,
                                            DLGN_obj=deepcopy(DLGN_init),
                                            parameter_mask=train_parameter_masks,
                                            )


torch.cuda.empty_cache()
losses=np.array(losses)

# Evaluate on the held-out test split (the returned model lives on the CPU).
with torch.no_grad():
    test_outputs_values, test_outputs_gate_scores =DLGN_obj_final(torch.Tensor(test_data))
test_preds = test_outputs_values[-1].detach().numpy()
# Positive score -> class 1, otherwise class 0.
Test_error=np.sum(test_data_labels != (np.sign(test_preds[:,0])+1)//2)
Num_test_data=len(test_data_labels)
print("Test_error=",Test_error)
print("Num_test_data=",Num_test_data)
print("Test Accuracy=",1-Test_error/Num_test_data)


Running code for input_dim=20, num_data=40000
18749
21251
Seed=  365
Running code for num_hidden_layers=4, num_hidden_nodes=[20, 20, 20, 20]
train_data_curr inside train_dlgn: (20000, 20)


  0%|          | 3/1501 [00:00<01:01, 24.22it/s]

tensor(0.6932, device='cuda:0', grad_fn=<NllLossBackward0>)


  7%|▋         | 105/1501 [00:03<00:51, 27.37it/s]

tensor(0.2603, device='cuda:0', grad_fn=<NllLossBackward0>)


 14%|█▎        | 204/1501 [00:07<00:49, 26.40it/s]

tensor(0.1798, device='cuda:0', grad_fn=<NllLossBackward0>)


 20%|██        | 304/1501 [00:11<00:42, 27.94it/s]

tensor(0.1358, device='cuda:0', grad_fn=<NllLossBackward0>)


 27%|██▋       | 403/1501 [00:15<00:40, 26.92it/s]

tensor(0.0948, device='cuda:0', grad_fn=<NllLossBackward0>)


 34%|███▎      | 503/1501 [00:18<00:36, 27.69it/s]

tensor(0.0620, device='cuda:0', grad_fn=<NllLossBackward0>)


 40%|████      | 605/1501 [00:22<00:31, 28.26it/s]

tensor(0.0373, device='cuda:0', grad_fn=<NllLossBackward0>)


 47%|████▋     | 704/1501 [00:26<00:28, 27.50it/s]

tensor(0.0179, device='cuda:0', grad_fn=<NllLossBackward0>)


 53%|█████▎    | 803/1501 [00:29<00:25, 27.45it/s]

tensor(0.1195, device='cuda:0', grad_fn=<NllLossBackward0>)


 60%|██████    | 905/1501 [00:33<00:21, 27.53it/s]

tensor(0.0027, device='cuda:0', grad_fn=<NllLossBackward0>)


 67%|██████▋   | 1003/1501 [00:36<00:17, 28.46it/s]

tensor(0.0010, device='cuda:0', grad_fn=<NllLossBackward0>)


 74%|███████▎  | 1105/1501 [00:40<00:14, 27.64it/s]

tensor(0.0004, device='cuda:0', grad_fn=<NllLossBackward0>)


 80%|████████  | 1205/1501 [00:44<00:10, 28.36it/s]

tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>)


 87%|████████▋ | 1305/1501 [00:48<00:06, 28.09it/s]

tensor(8.0779e-05, device='cuda:0', grad_fn=<NllLossBackward0>)


 93%|█████████▎| 1403/1501 [00:51<00:03, 28.86it/s]

tensor(3.9841e-05, device='cuda:0', grad_fn=<NllLossBackward0>)


100%|██████████| 1501/1501 [00:55<00:00, 27.25it/s]


tensor(2.0273e-05, device='cuda:0', grad_fn=<NllLossBackward0>)
Test_error= 924
Num_test_data= 10000
Test Accuracy= 0.9076
