# Applying BlueFog to Deep Learning Problems (High-Level API Introduction)

In all previous sections we focused on BlueFog's low-level API, which is well suited to flexible algorithm design and research. Here is a quick summary:

- Basic static topology properties and their manipulation
- Collective communication, such as broadcast and allreduce
- Topology-based neighborhood communication, such as `neighbor_allreduce`
- Blocking versus non-blocking operations
- Dynamic topologies
- Asynchronous operations through window objects
- Numerous algorithms and their performance under different scenarios
- etc.

However, applying one particular algorithm to different tasks with these APIs requires a lot of boilerplate, and writing efficient code that combines the concepts above is tricky. This becomes even harder for deep learning problems. Backpropagation lets the gradient of a neural network be computed efficiently, but it also means the gradient is computed (roughly) layer by layer, in contrast to the single global stochastic (sub-)gradient we encountered in the optimization examples. This layer-wise computation provides a great opportunity to overlap communication and computation, which is crucial for minimizing training time. Writing code that handles these concerns both correctly and efficiently is not easy. Hence, BlueFog also provides high-level APIs that can be applied directly to a `torch.optim.Optimizer`.
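You can observe the layer-by-layer property directly in plain PyTorch. The sketch below uses a toy two-layer network, not BlueFog code. It registers a hook on every parameter with `Tensor.register_hook` and records the order in which the gradients become ready during `backward()`. The last layer's gradients are ready before backpropagation reaches the first layer. A distributed optimizer can exploit this by starting communication for a layer as soon as that layer's gradient is available. The sketch only illustrates the opportunity and is not meant to show how BlueFog itself is implemented.

```python
import torch
import torch.nn as nn

# A tiny two-layer network standing in for a deep model.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

fired = []  # parameter names, in the order their gradients become ready


def make_hook(name):
    def hook(grad):
        # A distributed optimizer could start a non-blocking communication
        # for this parameter here while backprop continues on earlier layers.
        fired.append(name)
        # Returning None leaves the gradient unchanged.
    return hook


for name, param in model.named_parameters():
    param.register_hook(make_hook(name))

loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()
print(fired)  # backprop reaches the last layer ("2.*") first
```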

In this section, we focus on applying BlueFog's high-level APIs to deep learning problems, mainly decentralized training.
Before we demystify how the high-level APIs are implemented in BlueFog, let's see how easily they can be used to write distributed training of a ResNet-18 model on CIFAR-10. (*Note: although this example is relatively small, training it on CPUs can still be time-consuming and consume a lot of computational resources.*)


In [None]:
import os

import ipyparallel as ipp
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms, models

%matplotlib inline

In [None]:
rc = ipp.Client(profile="bluefog")
rc.block = True
rc.ids

In [None]:
# Download the CIFAR-10 dataset if it is not available.
# Since the dataset is small enough, we simply load it in memory.
train_dataset = datasets.CIFAR10(
    os.path.join(os.getcwd(), "..", "data"),
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

val_dataset = datasets.CIFAR10(
    os.path.join(os.getcwd(), "..", "data"),
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
# Distribute the data to each worker.
# Note that we push the full dataset to every worker only for simplicity;
# each worker reads just its own part of the dataset later.
_ = rc[:].push({"train_dataset": train_dataset, "val_dataset": val_dataset})

In [None]:
%%px
import math

import torch
import torch.nn.functional as F
from torchvision import models

import bluefog.torch as bf
from bluefog.common import topology_util

seed = 2021
bf.init()
torch.manual_seed(seed)
run_on_cuda = torch.cuda.is_available()
if run_on_cuda:
    print("using cuda.")
    # Bluefog: pin GPU to local rank.
    device_id = (bf.local_rank() if bf.nccl_built() else bf.local_rank() %
                 torch.cuda.device_count())
    torch.cuda.set_device(device_id)
    torch.cuda.manual_seed(seed)
else:
    print("using cpu")

In [None]:
%%px
# Prepare the distributed loader for dataset.
batch_size = 32
val_batch_size = 1024
kwargs = {"num_workers": 4, "pin_memory": True} if run_on_cuda else {}

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=bf.size(), rank=bf.rank())
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           sampler=train_sampler,
                                           **kwargs)

val_sampler = torch.utils.data.distributed.DistributedSampler(
    val_dataset, num_replicas=bf.size(), rank=bf.rank())
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=val_batch_size,
                                         sampler=val_sampler,
                                         **kwargs)
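The loaders above rely on `DistributedSampler` to give each worker its own shard. Every worker builds the same shuffled permutation, seeded with `seed + epoch`, and takes a strided slice of it. The slices are therefore disjoint and together cover the dataset, which is also why `train` calls `train_sampler.set_epoch(epoch)` before iterating. The sketch below uses a 12-element list as a stand-in dataset and simulates 4 workers in a single process, so no BlueFog or process group is involved.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(12))  # stand-in for a dataset with 12 samples
world_size = 4


def shard(rank, epoch):
    # num_replicas and rank are given explicitly, so no process group is needed.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    # Same seed + epoch on every worker -> same permutation everywhere.
    sampler.set_epoch(epoch)
    return list(sampler)


for epoch in range(2):
    shards = [shard(rank, epoch) for rank in range(world_size)]
    print(epoch, shards)
```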

In [None]:
%%px
model = models.resnet18(num_classes=10)
if run_on_cuda:
    model.cuda()

# Scale the learning rate by the number of workers.
base_lr = 0.0125
momentum = 0.9
weight_decay = 0.00005

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=(base_lr * bf.size()),
    momentum=momentum,
    weight_decay=weight_decay,
)

## Wrap the standard PyTorch optimizer into a BlueFog distributed optimizer

In [None]:
%%px
# True: adapt-then-combine (ATC); False: adapt-with-combine (AWC).
atc_style = False
base_dist_optimizer = (bf.DistributedAdaptThenCombineOptimizer if atc_style
                       else bf.DistributedAdaptWithCombineOptimizer)
optimizer = base_dist_optimizer(
    optimizer,
    model=model,
    communication_type=bf.CommunicationType.neighbor_allreduce)

# Bluefog: broadcast parameters & optimizer state.
bf.broadcast_parameters(model.state_dict(), root_rank=0)
bf.broadcast_optimizer_state(optimizer, root_rank=0)
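The two wrappers differ in the order of the local gradient step ("adapt") and the neighbor averaging ("combine"). The toy below illustrates this with the textbook forms of the two update rules, ignoring momentum and weight decay. It assumes three agents with scalar quadratic losses and a hypothetical combination matrix `W`:

- **Adapt-then-combine** averages the already-updated iterates.
- **Adapt-with-combine** averages the current iterates and adds each agent's own gradient step.

We assume BlueFog's two optimizers follow these forms. The code is a sketch, not BlueFog's implementation.

```python
import torch

# Toy setup (hypothetical): agent i minimizes f_i(x) = 0.5 * (x - b_i)^2.
b = torch.tensor([1.0, 2.0, 3.0])
W = torch.full((3, 3), 1.0 / 3)  # combination weights; each row sums to 1
lr = 0.1


def grad(x):
    # Gradient of each f_i, evaluated at that agent's own iterate.
    return x - b


def atc_step(x):
    # Adapt-then-combine: local gradient step first, then average the results.
    return W @ (x - lr * grad(x))


def awc_step(x):
    # Adapt-with-combine: average the current iterates, add the local step.
    return W @ x - lr * grad(x)


x0 = torch.zeros(3)
print(atc_step(x0))  # tensor([0.2000, 0.2000, 0.2000])
print(awc_step(x0))  # tensor([0.1000, 0.2000, 0.3000])
```

With ATC, the agents already agree after one step because they average the updated values. With AWC, each agent keeps its own gradient contribution. In return, the communication of `x` does not have to wait for the gradient.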

In [None]:
%%px
def accuracy(output, target):
    # Index of the largest logit, i.e., the predicted class.
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).cpu().float().mean()


class Metric(object):
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.0)  # pylint: disable=not-callable
        self.n = torch.tensor(0.0)  # pylint: disable=not-callable

    def update(self, val):
        # bf.allreduce (blocking) averages the value across all workers.
        self.sum += bf.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n


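# Generator yielding, at every iteration, the ranks this worker sends to
# and receives from under a dynamic one-peer schedule over the topology.
dynamic_neighbor_allreduce_gen = topology_util.GetDynamicOnePeerSendRecvRanks(
    bf.load_topology(), bf.rank())


def dynamic_topology_update(epoch, batch_idx):
    send_neighbors, recv_neighbors = next(dynamic_neighbor_allreduce_gen)
    optimizer.send_neighbors = send_neighbors
    # Uniform averaging over this worker and the ranks it receives from.
    optimizer.neighbor_weights = {
        r: 1 / (len(recv_neighbors) + 1)
        for r in recv_neighbors
    }
    optimizer.self_weight = 1 / (len(recv_neighbors) + 1)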


def adjust_learning_rate(epoch, batch_idx):
    if epoch < 5:  # warmup_epochs
        epoch += float(batch_idx + 1) / len(train_loader)
        lr_adj = 1.0 / bf.size() * (epoch *
                                    (bf.size() - 1) / 5 + 1)
    elif epoch < 30:
        lr_adj = 1.0
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        param_group["lr"] = (base_lr * bf.size() * lr_adj)
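`adjust_learning_rate` linearly warms the learning rate up over the first 5 epochs. It rises from roughly `base_lr` (multiplier `1 / bf.size()`) to `base_lr * bf.size()`, then decays step-wise. Restating the schedule as a pure function makes it easy to check without a BlueFog runtime. The name `lr_multiplier` and its arguments are hypothetical.

```python
def lr_multiplier(epoch, batch_idx, num_batches, world_size, warmup_epochs=5):
    """Hypothetical pure-function restatement of adjust_learning_rate.

    The result multiplies base_lr * world_size, as in the original code.
    """
    if epoch < warmup_epochs:
        # Fractional epoch, so the ramp also advances within an epoch.
        progress = epoch + float(batch_idx + 1) / num_batches
        return 1.0 / world_size * (progress * (world_size - 1) / warmup_epochs + 1)
    if epoch < 30:
        return 1.0
    if epoch < 60:
        return 1e-1
    if epoch < 80:
        return 1e-2
    return 1e-3


# 4 workers, 100 batches per epoch.
print(lr_multiplier(2, 49, 100, 4))   # 0.625, halfway through the warmup
print(lr_multiplier(4, 99, 100, 4))   # 1.0, end of the warmup
print(lr_multiplier(45, 0, 100, 4))   # 0.1
```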

In [None]:
%%px
def train(epoch):
    model.train()
    train_sampler.set_epoch(epoch)
    train_loss = Metric("train_loss")
    train_accuracy = Metric("train_accuracy")

    for batch_idx, (data, target) in enumerate(train_loader):
        adjust_learning_rate(epoch, batch_idx)
        dynamic_topology_update(epoch, batch_idx)

        if run_on_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        train_accuracy.update(accuracy(output, target))
        loss = F.cross_entropy(output, target)
        train_loss.update(loss)
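        # Average gradients among sub-batches; len(data) never exceeds
        # batch_size here, so the divisor is 1.
        loss.div_(math.ceil(float(len(data)) / batch_size))
        loss.backward()
        # Local SGD step, combined with the neighbors' parameters as
        # configured in the BlueFog distributed optimizer.
        optimizer.step()

    # Metric values are averaged across workers, so rank 0 can report them.
    if bf.rank() == 0:
        print(f"Epoch {epoch}: train loss {train_loss.avg.item():.4f}, "
              f"train accuracy {train_accuracy.avg.item():.4f}")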


def validate(epoch):
    model.eval()
    val_loss = Metric("val_loss")
    val_accuracy = Metric("val_accuracy")

    with torch.no_grad():
        for data, target in val_loader:
            if run_on_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            val_loss.update(F.cross_entropy(output, target))
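            val_accuracy.update(accuracy(output, target))

    # Metric values are averaged across workers, so rank 0 can report them.
    if bf.rank() == 0:
        print(f"Epoch {epoch}: val loss {val_loss.avg.item():.4f}, "
              f"val accuracy {val_accuracy.avg.item():.4f}")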
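Each `Metric` averages in two stages. First, `bf.allreduce` averages a batch's value across workers, which we assume is its default behavior. Then `avg` divides the accumulated sum by the number of batches. Every worker therefore ends up with the same value. The sketch below mimics this in a single process. `fake_allreduce` is a hypothetical stand-in for `bf.allreduce`, and the per-batch numbers are purely illustrative.

```python
import torch


def fake_allreduce(values):
    # Hypothetical stand-in for bf.allreduce over all workers, assuming its
    # default behavior of averaging (not summing) the tensors.
    return torch.stack(values).mean(dim=0)


# Illustrative per-batch accuracies from 2 workers over 3 batches.
per_batch = [
    [torch.tensor(0.50), torch.tensor(0.70)],
    [torch.tensor(0.60), torch.tensor(0.80)],
    [torch.tensor(0.90), torch.tensor(0.70)],
]

# Mirror Metric.update / Metric.avg: accumulate the cross-worker average
# of every batch, then divide by the number of batches.
total, n = torch.tensor(0.0), torch.tensor(0.0)
for worker_values in per_batch:
    total += fake_allreduce(worker_values)
    n += 1
avg = total / n
print(avg.item())
```

Every batch carries equal weight in `avg`. `DistributedSampler` pads the shards so that all workers see the same number of batches, which keeps the blocking allreduce calls matched across workers.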

In [None]:
epochs = 25
for epoch in range(epochs):
    # %px executes on the engines, so `epoch` must be defined there.
    rc[:].push({"epoch": epoch})
    %px train(epoch)
    %px validate(epoch)
    # TODO pull the accuracy and result out from worker.

# Demystify the BlueFog DistributedAdaptThenCombineOptimizer