In [1]:
import pytest

pytestmark = [pytest.mark.unit, pytest.mark.release]

import numpy as np
from thirdai import bolt

from utils import (
    gen_single_sparse_layer_network,
    gen_numpy_training_data,
    get_categorical_acc,
    train_network,
    build_simple_distributed_bolt_network,
    train_single_node_distributed_network,
)


# gets compressed gradients as (indices,values) tuple and then asserts that the gradient matrix after setting is the same as the tuple


In [2]:
network = build_simple_distributed_bolt_network(sparsity=1.0, n_classes=10)
examples, labels = gen_numpy_training_data(n_classes=10, n_samples=1000)
train_single_node_distributed_network(
    network, examples, labels, epochs=1, update_parameters=False
)

indices, values = network.get_indexed_sketch_for_gradients(
    layer_index=0, compression_density=0.1, sketch_biases=False, seed_for_hashing=0
)
print(indices,values)
network.set_gradients_from_indices_values(
    layer_index=0, indices=indices, values=values, set_biases=True
)

network_weights = network.get_weights_gradients(0).flatten()
np.add.at(network_weights, indices, -1 * values)
norm_after_subtracting_gradients = np.linalg.norm(network_weights)

assert norm_after_subtracting_gradients == 0


Initializing Bolt network...
InputLayer (Layer 0): dim=10
FullyConnectedLayer (Layer 1): dim=50, sparsity=1, act_func=ReLU
FullyConnectedLayer (Layer 2): dim=10, sparsity=1, act_func=Softmax
Initialized Network in 0 seconds
Distributed Network initialization done on this Node
[491 293 138  75   0 283 396  45 367   0 170   0   0 420  42   0   0   0
 279 161 499   0  87 438 119  49  16  60   0 179   0 142 437 263  15  86
 266   0   0 267  11 176   0 166 307   0 203 110  79 465] [-0.02549389 -0.01961542 -0.01972435  0.02007152  0.         -0.01751803
  0.02181056  0.0225569   0.03090289  0.          0.02882907  0.
  0.          0.0220595  -0.0200375   0.          0.          0.
  0.01677938  0.0254018   0.01869166  0.         -0.01916217 -0.0210577
  0.02152697  0.0174329   0.03116976  0.0197744   0.         -0.01706735
  0.          0.02939146  0.02065227 -0.02022854 -0.01776834 -0.02401814
  0.02515446  0.          0.          0.01765273  0.01659583 -0.02257535
  0.          0.02209039 

AssertionError: 