# Bidirectional Compression


## 1. Implement an environment that emulates communication between the workers and the server


In [1]:
import time
import numpy as np
from sklearn.utils import gen_batches


class NoopServerToWorkerCompressor:
    """Pass-through server-to-worker "compressor" that sends data unchanged.

    The only state is ``transmitted_coordinates``, the number of scalar
    values handed to the most recent ``compress`` call.
    """

    def __init__(self):
        # Must exist before the first ``compress`` call so that callers can
        # read the counter even when nothing has been sent yet.
        self.transmitted_coordinates = 0

    def compress(self, X, y):
        """Record the payload size and return ``X`` and ``y`` untouched.

        Args:
            X: Array-like payload (e.g. a feature matrix).
            y: Array-like payload (e.g. labels); may be 1-D or a column.

        Returns:
            The tuple ``(X, y)``, the very same objects that were passed in.
        """
        # np.size counts every element regardless of rank, so a 1-D label
        # vector is handled the same way as an (n, 1) column.
        # NOTE(review): the counter is overwritten, not accumulated, on each
        # call - confirm this matches how other compressors count.
        self.transmitted_coordinates = np.size(X) + np.size(y)
        return X, y

    def decompress(self, X, y):
        """Return ``(X, y)`` unchanged; nothing was compressed."""
        return X, y


class DistributedEnvSimulator:
    """Simulate distributed gradient descent with compressed communication.

    The data set is sharded across ``n_devices`` simulated workers. On every
    iteration each worker computes a local gradient, compresses it and sends
    it to the server; the server decompresses and averages the gradients,
    takes a gradient step and hands the new weights back to the workers.
    """

    def __init__(
        self,
        nabla_f,
        gamma_k,
        worker_to_server_compressor,
        server_to_worker_compressor=None,
    ):
        """Store the callables and compressors used by the simulation.

        Args:
            nabla_f: Callable ``nabla_f(X, y, weights, lambda_reg)`` returning
                the local gradient.
            gamma_k: Callable ``gamma_k(L, iteration)`` returning the step
                size for the given iteration.
            worker_to_server_compressor: Object providing
                ``compress(gradient) -> (compressed, indices)``,
                ``decompress(compressed, indices, size)`` and a
                ``transmitted_coordinates`` attribute.
            server_to_worker_compressor: Compressor for the server-to-worker
                direction. When omitted, a fresh
                ``NoopServerToWorkerCompressor`` is created.
        """
        self.nabla_f = nabla_f
        self.gamma_k = gamma_k
        self.worker_to_server_compressor = worker_to_server_compressor
        # Build the default per instance: a default evaluated in the signature
        # would be one object shared (with its counter) by every simulator.
        if server_to_worker_compressor is None:
            server_to_worker_compressor = NoopServerToWorkerCompressor()
        self.server_to_worker_compressor = server_to_worker_compressor

    def simulate_distributed_env(
        self,
        X,
        y,
        n_devices,
        num_iterations=100,
        eps=None,
    ):
        """Run the simulated distributed optimisation.

        Args:
            X: 2-D array of shape ``(n_samples, n_features)``.
            y: Labels aligned with the rows of ``X``.
            n_devices: Number of simulated workers; must be between 1 and the
                number of samples.
            num_iterations: Maximum number of server iterations.
            eps: Optional tolerance; the loop stops once the norm of the
                averaged gradient drops below it.

        Returns:
            A tuple ``(weights, convergence, accuracies, execution_time,
            transmitted_coordinates_worker_to_server,
            transmitted_coordinates_server_to_worker)`` where ``convergence``
            and ``accuracies`` hold one value per executed iteration.

        Raises:
            ValueError: If ``n_devices`` is outside ``[1, n_samples]``.
        """
        n_samples = X.shape[0]
        n_features = X.shape[1]
        if not 1 <= n_devices <= n_samples:
            raise ValueError(
                f"n_devices must be between 1 and {n_samples}, got {n_devices}"
            )

        weights = np.zeros(n_features)
        device_weights = [np.zeros(n_features) for _ in range(n_devices)]
        convergence = []
        accuracies = []
        execution_time = 0

        # The sum of squared row norms equals the sum of all squared entries;
        # written this way it does not require np.linalg.vector_norm
        # (NumPy >= 2.0 only).
        # NOTE(review): the divisor is the feature count. If L is meant to be
        # (1 / (4 n)) * sum ||x_i||^2 over n samples it should be X.shape[0]
        # - confirm against the loss behind nabla_f.
        L = np.sum(np.square(X)) / (4 * n_features)

        # Exactly one shard per device, so no sample is silently left out.
        X_split, y_split = self._split_data(
            X, y, n_samples // n_devices, n_parts=n_devices
        )

        for iteration in range(num_iterations):
            start_time = time.perf_counter()
            aggregated_gradient = np.zeros(n_features)

            for device_idx in range(n_devices):
                compressed_gradient, indices = self._execute_on_device(
                    X_split[device_idx],
                    y_split[device_idx],
                    device_weights[device_idx],
                    L,
                )

                # Assumes the third argument is the length of the dense
                # output, which must match the n_features-long accumulator.
                aggregated_gradient += self.worker_to_server_compressor.decompress(
                    compressed_gradient, indices, n_features
                )

            gamma = self.gamma_k(L, iteration)
            weights -= gamma * (aggregated_gradient / n_devices)

            # Hand the updated model back to every worker; without this the
            # workers would keep computing gradients at their initial weights.
            # NOTE(review): this is a plain copy; the server-to-worker
            # compressor visible in this file takes (X, y), not a weight
            # vector, so it is not applied here.
            for device_idx in range(n_devices):
                device_weights[device_idx] = weights.copy()

            execution_time += time.perf_counter() - start_time

            accuracy_i = self._estimate_accuracy(X, y, weights)
            accuracies.append(accuracy_i)
            # _estimate_convergence performs the division by n_devices itself.
            convergence_i = self._estimate_convergence(aggregated_gradient, n_devices)
            convergence.append(convergence_i)

            # Stop on a small gradient norm: eps is a tolerance, and comparing
            # it against accuracy would stop the run while accuracy is low.
            if eps is not None and convergence_i < eps:
                break

        transmitted_coordinates_worker_to_server = (
            self.worker_to_server_compressor.transmitted_coordinates
        )
        # The server-side compressor may never have been invoked; report 0
        # when it has not recorded anything.
        transmitted_coordinates_server_to_worker = getattr(
            self.server_to_worker_compressor, "transmitted_coordinates", 0
        )

        return (
            weights,
            convergence,
            accuracies,
            execution_time,
            transmitted_coordinates_worker_to_server,
            transmitted_coordinates_server_to_worker,
        )

    def _estimate_accuracy(self, X, y, weights):
        """Return the fraction of rows whose ``sign(X @ weights)`` equals ``y``.

        Labels are flattened first so that a column vector ``(n, 1)`` is
        compared element-wise rather than broadcast into an ``n x n`` matrix.
        A zero score gives prediction 0, which only matches a label of 0.
        """
        y_pred = np.sign(np.dot(X, weights)).astype("int")
        y_true = np.ravel(y).astype("int")
        false_predictions = int(np.count_nonzero(y_true != y_pred))
        return 1 - false_predictions / len(y_pred)

    def _estimate_convergence(self, aggregated_gradient, n_devices):
        """Return the Euclidean norm of the gradient averaged over devices."""
        return np.linalg.norm(aggregated_gradient / n_devices)

    def _execute_on_device(self, X, y, weights, L):
        """Compute one worker's gradient and compress it for upload.

        Args:
            X: The worker's shard of the feature matrix.
            y: The worker's shard of the labels.
            weights: The worker's current weight vector.
            L: Constant computed by the caller; the regularisation strength
                passed to ``nabla_f`` is ``L / 1000``.

        Returns:
            ``(compressed_gradient, indices)`` as produced by the
            worker-to-server compressor.
        """
        lambda_reg = L / 1000
        gradient_device = self.nabla_f(X, y, weights, lambda_reg)
        return self.worker_to_server_compressor.compress(gradient_device)

    def _split_data(self, X, y, batch_size, n_parts=None):
        """Split ``X`` and ``y`` into aligned consecutive shards.

        Args:
            X: Feature matrix.
            y: Labels aligned with the rows of ``X``.
            batch_size: Shard size used when ``n_parts`` is not given; the
                split is then delegated to ``gen_batches`` with
                ``min_batch_size=batch_size``.
            n_parts: When given, ``batch_size`` is ignored and the data is cut
                into exactly this many near-equal shards.

        Returns:
            ``(X_batched, y_batched)``: two lists of equal length.
        """
        if n_parts is not None:
            return np.array_split(X, n_parts), np.array_split(y, n_parts)

        X_batched = []
        y_batched = []
        for batch_indices in gen_batches(n=len(X), batch_size=batch_size, min_batch_size=batch_size):
            X_batched.append(X[batch_indices])
            y_batched.append(y[batch_indices])
        return X_batched, y_batched

## 2. Implement CGD