In [1]:
import numpy as np
from thirdai import distributed_bolt
from thirdai import bolt

In [2]:
from utils import (
    gen_single_sparse_layer_network,
    gen_training_data,
    get_categorical_acc,train_network,train_single_node_distributed_network
)

In [3]:
def build_sparse_hidden_layer_classifier(input_dim, sparse_dim, output_dim, sparsity):
    """Build a two-layer Bolt ``DistributedNetwork`` classifier.

    The stack is: input (``input_dim``) -> fully connected ReLU hidden layer
    (``sparse_dim`` units, configured with ``sparsity``) -> fully connected
    Softmax output layer (``output_dim`` units).

    Args:
        input_dim: Dimension of the network's input layer.
        sparse_dim: Number of units in the hidden layer.
        sparsity: Sparsity passed to the hidden layer. Presumably this is the
            fraction of active hidden neurons; 1.0 is logged as ``sparsity=1``.
        output_dim: Number of units in the Softmax output layer.

    Returns:
        The constructed ``bolt.DistributedNetwork``.
    """
    hidden_layer = bolt.FullyConnected(
        dim=sparse_dim,
        sparsity=sparsity,
        activation_function="ReLU",
    )
    output_layer = bolt.FullyConnected(dim=output_dim, activation_function="Softmax")
    return bolt.DistributedNetwork(layers=[hidden_layer, output_layer], input_dim=input_dim)

In [4]:
# A tiny 10 -> 10 -> 10 classifier. sparsity=1.0 is logged as "sparsity=1"
# for the hidden layer (presumably all hidden neurons are active).
input_dim = hidden_dim = output_dim = 10
network = build_sparse_hidden_layer_classifier(
    input_dim=input_dim,
    sparse_dim=hidden_dim,
    output_dim=output_dim,
    sparsity=1.0,
)

Initializing Bolt network...
InputLayer (Layer 0): dim=10
FullyConnectedLayer (Layer 1): dim=10, sparsity=1, act_func=ReLU
FullyConnectedLayer (Layer 2): dim=10, sparsity=1, act_func=Softmax
Initialized Network in 0 seconds


In [5]:
# Generate training data and train the network on a single node.
# n_classes is tied to output_dim (set where the network is built) so the labels
# always match the Softmax output layer's width. Previously this was a
# hard-coded 10 that could silently drift.
# NOTE(review): the example width must also match input_dim. gen_training_data
# takes no dimension argument here, so confirm its output width in utils.
examples, labels = gen_training_data(n_classes=output_dim, n_samples=1000)
train_single_node_distributed_network(network, examples, labels, epochs=10)

Distributed Network initialization done on this Node


In [6]:
# Sketch layer 0's weight gradients into an (indices, values) pair of arrays.
# For this 10x10 layer, compression_density=0.2 yields 20 slots (see output).
# Indices are row-major flat positions in the gradient matrix printed by the
# next cell (e.g. 34 -> row 3, col 4). Some slots come back as (0, 0.0).
# `x` keeps its name because the set_gradients cell reads x[0] / x[1].
x = network.get_indexed_sketch(
    layer_index=0,
    compression_density=0.2,
    is_set_biases=False,
    seed=2,
)
print(x)

(array([34, 43, 58,  0,  0,  0, 81,  0, 28, 75,  0, 97,  0,  0,  0,  0,  0,
        0,  0, 11], dtype=int32), array([0.29945943, 0.39426008, 1.099941  , 0.        , 0.        ,
       0.        , 0.800386  , 0.        , 0.4757971 , 0.5506324 ,
       0.        , 1.0056921 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.6406704 ],
      dtype=float32))


In [7]:
# Read layer 0's weight *gradients* (a 10x10 matrix for this network).
# Despite its name, `weights` holds gradients, not weights. The name is kept
# because other cells may refer to it.
# NOTE(review): if the binding returns a view instead of a copy, this array would
# also change after set_gradients. Consider .copy() for a true "before" snapshot.
weights=network.get_weights_gradients(0)
print(weights)

[[-6.2827528e-02  1.2125643e-01  1.2198184e+00 -5.4874510e-01
  -5.4618531e-01  7.3684049e-01 -1.0509858e+00 -3.9472710e-03
  -1.9405480e-02 -1.2467495e+00]
 [-4.1024616e-01  6.4067042e-01  8.1941664e-01 -7.8256863e-01
  -4.1121244e-01  3.6604546e-02 -1.2759462e-01 -1.4228326e-02
   6.1198258e-01 -3.1955002e-03]
 [-1.4160929e+00  7.3317230e-02 -1.4296468e+00 -4.8410036e-03
   2.0894678e-02 -3.9055154e-01  3.2895189e-02  6.9374271e-04
   4.7579709e-01  8.2005240e-02]
 [-2.5076577e-02 -2.6229411e-01  3.8832155e-01  2.0307156e-01
   2.9945943e-01 -6.1383605e-01 -7.8046359e-02 -6.5206401e-02
   7.0070118e-01 -7.9829404e-03]
 [-1.2483405e-01  5.0440764e-01 -6.4193942e-03  3.9426008e-01
  -2.8569987e-01  4.5469802e-02  2.7203234e-02 -9.6513711e-02
   2.5608214e-02 -2.2273309e+00]
 [-9.9060111e-02 -1.0530164e+00 -1.1902996e+00  6.6163816e-02
   9.4474268e-01 -4.7885146e-02  4.7076571e-01  1.0880262e+00
   1.0999410e+00 -6.4601010e-01]
 [-2.8001111e-02 -6.0928864e-03  9.8206289e-03 -3.1710796e

In [8]:
# Write the sketch back as layer 0's weight gradients. The sketch is the
# (flat indices, values) tuple returned by get_indexed_sketch. Per the next
# cell's output, every gradient outside the sketch ends up zero.
sketch_indices, sketch_values = x
network.set_gradients(
    layer_index=0,
    indices=sketch_indices,
    values=sketch_values,
    is_set_biases=False,
)

inside the set gradients from tuple function


In [10]:
# Re-read layer 0's weight gradients after set_gradients. The matrix should now
# hold only the sketched values: the entry at row-major flat position i equals
# the sketch value for index i, and every other entry is zero.
weights_new=network.get_weights_gradients(0)
print(weights_new)

# Check the round trip with assertions instead of eyeballing the printed matrix.
flat_grads = np.asarray(weights_new).ravel()
sketch_idx = np.asarray(x[0])
sketch_vals = np.asarray(x[1])
# The sketch contains (index 0, value 0.0) slots, presumably unfilled slots,
# so compare only the entries that carry a nonzero value.
filled = sketch_vals != 0
np.testing.assert_allclose(flat_grads[sketch_idx[filled]], sketch_vals[filled], rtol=1e-6)
untouched = np.ones(flat_grads.size, dtype=bool)
untouched[sketch_idx] = False
assert not flat_grads[untouched].any(), "gradients outside the sketch should be zero"

[[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.6406704  0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.4757971  0.        ]
 [0.         0.         0.         0.         0.29945943 0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.39426008 0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         1.099941   0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.5506324
  0.         0.         0.         0.        ]
 [0.         0.800386   0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.    