# Deploy a communication compression scheme in FedLab

This tutorial provides comprehensive examples of implementing a communication-efficient scheme in FedLab.

We use two baseline gradient compression algorithms as examples: top-k for gradient sparsification and QSGD for gradient quantization.

## Compression example

In [1]:
from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
import torch

# Top-k sparsification: keep the top 5% of gradient entries.
tpk_compressor = TopkCompressor(compress_ratio=0.05)
# QSGD quantization with 8 bits.
qsgd_compressor = QSGDCompressor(n_bit=8)

In [2]:
# top-k: round-trip a random 100-element tensor through the top-k compressor
tensor = torch.randn(size=(100,))
shape = tensor.shape  # kept so decompress() can rebuild the dense tensor
print("To be compressed tensor:", tensor)

# compress -> (values, indices)
values, indices = tpk_compressor.compress(tensor)
for caption, part in (("Compressed results top-k values:", values),
                      ("Compressed results top-k indices:", indices)):
    print(caption, part)

# decompress back to the original shape
decompressed = tpk_compressor.decompress(values, indices, shape)
print("Decompressed results:", decompressed)

To be compressed tensor: tensor([-0.2551,  0.1723, -0.3045,  0.1142,  1.4212, -0.0035, -1.4607, -0.8500,
         2.1810,  0.5937,  0.2021,  0.2862,  1.3127,  0.1277, -1.0769,  0.2209,
         1.1807,  0.5525,  0.1321,  0.0979,  0.3020,  0.1013,  3.4029, -0.2621,
         0.0406,  0.6202, -0.3686, -1.9036,  0.9475,  1.3832,  0.4061, -0.4844,
        -1.1010, -0.8294, -0.5312,  0.4947,  0.0112,  0.9657,  1.2832,  1.8749,
         0.3791,  0.3394,  2.7730,  0.5408,  1.9294, -0.6163,  1.0113, -1.9415,
         0.1310, -2.1150, -0.2764, -0.3573, -1.2982,  1.3719,  1.1704,  1.5341,
         0.8125,  0.1082, -0.0353,  1.2760,  1.4667,  0.3195,  0.2799, -1.2045,
         0.6300, -1.2554, -0.2688, -0.0793, -1.0269, -1.2192, -0.0565,  1.5932,
         0.4082, -0.0509, -0.2076, -1.4379, -0.3355, -0.0201, -0.0423, -0.6073,
        -0.3837,  0.0933, -1.0693, -0.5408, -0.3400,  0.9970, -0.5036,  1.5442,
         1.2167,  0.6450, -1.3357, -0.5427, -0.1329,  1.1694,  0.3605,  2.4106,
        -1.8797

In [3]:
# qsgd: quantize a fresh random 100-element tensor (decompressed in the next cell)
tensor = torch.randn(size=(100,))
shape = tensor.shape
print("To be compressed tensor:", tensor)

# compress -> (norm, signs, values)
norm, signs, values = qsgd_compressor.compress(tensor)
for caption, part in (("Compressed results QSGD norm:", norm),
                      ("Compressed results QSGD signs:", signs),
                      ("Compressed results QSGD values:", values)):
    print(caption, part)


To be compressed tensor: tensor([ 0.8138, -0.4996, -0.1319, -0.2596, -0.1408,  2.0757,  0.2944,  1.2920,
        -1.4043,  0.1464, -0.5552,  0.1523, -0.7994,  0.8580,  0.4179, -0.0481,
         0.7103,  0.3520, -0.5581,  0.4171, -1.1006,  0.9965, -0.4987,  0.7205,
        -0.8855, -0.3032,  0.1235,  0.0968,  0.2313, -1.7149, -1.0833,  0.3717,
        -1.2076, -0.4502, -1.8539,  0.9081,  1.7861,  0.6766,  0.2267, -0.4099,
         1.8331, -1.4917,  0.2764,  0.2297, -0.3605, -0.0055,  0.6125,  1.4960,
         0.5952,  2.1638, -2.5469, -0.8668,  0.3518,  0.6117, -1.7751, -0.7952,
         2.3184, -0.4247,  0.9324, -0.7573, -0.5979,  1.0592,  0.5357, -0.7331,
         0.9007, -0.4086, -0.2537, -4.1050, -0.1990,  0.5406, -0.2515, -0.5306,
        -0.7901, -0.8744,  1.3315, -0.8028, -0.3142,  0.5809,  1.2843, -1.4680,
         0.2522,  0.6385, -1.1771,  0.5757, -0.8276, -0.5795,  1.8992, -1.0554,
         1.5926, -0.9995, -0.9659,  0.4166,  0.1454,  1.2823, -0.4326,  1.0771,
        -0.5909

In [4]:
# decompress: rebuild the tensor from the QSGD (norm, signs, values) triple
compressed_parts = [norm, signs, values]
decompressed = qsgd_compressor.decompress(compressed_parts)
print("Decompressed results:", decompressed)

Decompressed results: tensor([ 0.8018, -0.4971, -0.1283, -0.2726, -0.1443,  2.0686,  0.2886,  1.2828,
        -1.4111,  0.1443, -0.5612,  0.1604, -0.8018,  0.8499,  0.4169, -0.0481,
         0.7216,  0.3367, -0.5612,  0.4169, -1.1064,  0.9942, -0.4971,  0.7216,
        -0.8819, -0.3047,  0.1283,  0.0962,  0.2245, -1.7158, -1.0744,  0.3688,
        -1.2187, -0.4490, -1.8441,  0.8980,  1.7799,  0.6895,  0.2245, -0.4169,
         1.8280, -1.4913,  0.2726,  0.2405, -0.3528, -0.0160,  0.6093,  1.5073,
         0.5933,  2.1648, -2.5496, -0.8659,  0.3528,  0.6093, -1.7799, -0.7857,
         2.3251, -0.4169,  0.9300, -0.7537, -0.6093,  1.0583,  0.5452, -0.7376,
         0.8980, -0.4169, -0.2566, -4.1050, -0.2085,  0.5452, -0.2566, -0.5292,
        -0.8018, -0.8659,  1.3309, -0.8018, -0.3207,  0.5773,  1.2828, -1.4592,
         0.2405,  0.6254, -1.1706,  0.5773, -0.8338, -0.5773,  1.8922, -1.0583,
         1.5875, -0.9942, -0.9782,  0.4169,  0.1443,  1.2828, -0.4330,  1.0744,
        -0.5933,  

## Use compressor in federated learning

For example, on the client side we can compress the tensors to be uploaded and send the compressed results to the server. The server then decompresses the tensors according to the agreed compression scheme.

In this Jupyter notebook, we use the standalone scenario as an example.

In [5]:
from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

class CompressSerialClientTrainer(SGDSerialClientTrainer):
    """Serial SGD client trainer whose uploads are compressed.

    Call ``setup_compressor`` before reading ``uplink_package``; the
    compressor is applied to the first entry of every package produced by
    ``SGDSerialClientTrainer.uplink_package``.
    """

    def setup_compressor(self, compressor):
        """Attach the compressor (e.g. ``TopkCompressor`` or ``QSGDCompressor``)."""
        self.compressor = compressor

    @property
    def uplink_package(self):
        """Return the parent's uplink packages with their first entry compressed.

        Each returned package is a one-element list holding the compressor's
        output for ``content[0]``; any further entries of the parent package
        are not uploaded (the server handler in this notebook only reads
        ``payload[0]``).
        """
        package = super().uplink_package
        return [[self.compressor.compress(content[0])] for content in package]

class CompressServerHandler(SyncServerHandler):
    """Synchronous server handler that decompresses client uploads before aggregation.

    Call ``setup_compressor`` before training; the compressor and scheme name
    must match what the clients used to build their uplink packages.
    """

    # Scheme names accepted by ``setup_compressor`` and handled by ``load``.
    SUPPORTED_TYPES = ("topk", "qsgd")

    def setup_compressor(self, compressor, type):
        """Attach the compressor used to decode client uploads.

        Args:
            compressor: The compressor instance shared with the clients.
            type: Compression scheme name, ``"topk"`` or ``"qsgd"``. (The name
                shadows the builtin but is kept for keyword-argument callers.)

        Raises:
            ValueError: If ``type`` is not a supported scheme.
        """
        if type not in self.SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported compression type {type!r}; "
                f"expected one of {self.SUPPORTED_TYPES}")
        self.compressor = compressor
        self.type = type

    def load(self, payload) -> bool:
        """Decompress one client upload and pass it to ``SyncServerHandler.load``.

        Args:
            payload: A list whose first item is the compressor output for one
                client: ``(values, indices)`` for top-k, or
                ``(norm, signs, values)`` for QSGD.

        Returns:
            The result of ``SyncServerHandler.load`` on the dense payload.

        Raises:
            ValueError: If ``self.type`` is not a supported scheme.
        """
        if self.type == "topk":
            values, indices = payload[0]
            # top-k only transmits the kept entries, so the dense shape is taken
            # from the server's own copy of the model parameters.
            decompressed_payload = self.compressor.decompress(
                values, indices, self.model_parameters.shape)
        elif self.type == "qsgd":
            norm, signs, values = payload[0]
            decompressed_payload = self.compressor.decompress((norm, signs, values))
        else:
            # Previously an unknown type fell through both branches and failed
            # with an UnboundLocalError on ``decompressed_payload``.
            raise ValueError(f"Unsupported compression type {self.type!r}")

        return super().load([decompressed_payload])


# Backward-compatible alias for the original (misspelled) class name.
CompressServerHandeler = CompressServerHandler

In [6]:
# main: this part follows the pipeline in pipeline_tutorial.ipynb,
# but replaces the handler and trainer with the ones defined above for
# communication compression.

# configuration
import torch
from munch import Munch
from fedlab.models.mlp import MLP

model = MLP(784, 10)
# Instantiate Munch: assigning to the bare class would store every setting
# as a class attribute on Munch itself.
args = Munch()

args.total_client = 100
args.alpha = 0.5
args.seed = 42
args.preprocess = False
args.cuda = torch.cuda.is_available()  # fall back to CPU when no GPU is present
args.cmp_op = "qsgd"  # one of "topk", "qsgd"

args.k = 0.1  # topk: fraction of entries kept
args.bit = 8  # qsgd: number of quantization bits

if args.cmp_op == "topk":
    compressor = TopkCompressor(compress_ratio=args.k)
elif args.cmp_op == "qsgd":
    compressor = QSGDCompressor(n_bit=args.bit)
else:
    raise ValueError(
        f"Unsupported compression op {args.cmp_op!r}; expected 'topk' or 'qsgd'")

from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST

fed_mnist = PartitionedMNIST(root="./datasets/mnist/",
                             path="./datasets/mnist/fedmnist/",
                             num_clients=args.total_client,
                             partition="noniid-labeldir",
                             dir_alpha=args.alpha,
                             seed=args.seed,
                             preprocess=args.preprocess,
                             download=True,
                             verbose=True,
                             transform=transforms.Compose([
                                 transforms.ToPILImage(),
                                 transforms.ToTensor()
                             ]))

dataset = fed_mnist.get_dataset(0)  # get the 0-th client's dataset
dataloader = fed_mnist.get_dataloader(
    0,
    batch_size=128)  # get the 0-th client's dataset loader with batch size 128


In [7]:
# client and server setup
from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

# local training configuration
args.epochs = 5
args.batch_size = 128
args.lr = 0.1

# global configuration
args.com_round = 10
args.sample_ratio = 0.1

# client: serial trainer over args.total_client clients
trainer = CompressSerialClientTrainer(model, args.total_client, cuda=args.cuda)
# single-client alternative: trainer = SGDClientTrainer(model, cuda=True)
trainer.setup_dataset(fed_mnist)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_compressor(compressor)

# server: decompresses uploads using the same compressor and scheme name
handler = CompressServerHandeler(model=model,
                                 global_round=args.com_round,
                                 sample_ratio=args.sample_ratio,
                                 cuda=args.cuda)
handler.setup_compressor(compressor, args.cmp_op)

In [8]:
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision

class EvalPipeline(StandalonePipeline):
    """Standalone pipeline that evaluates the global model after every round.

    Args:
        handler: Server handler that samples clients, builds the downlink
            package and loads client uploads.
        trainer: Client trainer that runs local training for sampled clients.
        test_loader: DataLoader over the held-out test set.
    """

    def __init__(self, handler, trainer, test_loader):
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run rounds until ``handler.if_stop`` is truthy, printing test loss/accuracy each round."""
        # Built once instead of once per round.
        criterion = nn.CrossEntropyLoss()
        # Truthiness instead of ``is False``: a non-bool falsy flag (e.g. a
        # numpy bool) would otherwise stop the loop before it ever runs.
        while not self.handler.if_stop:
            # server side: sample participants and package the global model
            sampled_clients = self.handler.sample_clients()
            broadcast = self.handler.downlink_package

            # client side: local training, then collect the uploads
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package

            # server side: feed each client's upload to the handler
            for pack in uploads:
                self.handler.load(pack)

            loss, acc = evaluate(self.handler.model, criterion, self.test_loader)
            print("loss {:.4f}, test accuracy {:.4f}".format(loss, acc))


# Held-out MNIST test split. It lives under a different root than the training
# data, so download=True fetches it when it is not already on disk.
test_data = torchvision.datasets.MNIST(root="./tests/data/mnist/",
                                       train=False,
                                       download=True,
                                       transform=transforms.ToTensor())
test_loader = DataLoader(test_data, batch_size=1024)

standalone_eval = EvalPipeline(handler=handler,
                               trainer=trainer,
                               test_loader=test_loader)
standalone_eval.main()

loss 21.9152, test accuracy 0.2445
loss 17.1248, test accuracy 0.5201
loss 14.0367, test accuracy 0.5140
loss 9.1731, test accuracy 0.7470
loss 7.4824, test accuracy 0.8066
loss 6.3856, test accuracy 0.8209
loss 5.9895, test accuracy 0.8164
loss 5.3765, test accuracy 0.8324
loss 4.6479, test accuracy 0.8707
loss 4.7858, test accuracy 0.8604
