# Activation and Gradient Compression for Model Parallelism

Experiment code for training ResNet-18 on CIFAR-10 with activation and gradient compression.

## Imports, Installs, Versions

In [None]:
!pip install wandb

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import sys
import random
import copy
import wandb
# import lovely_tensors as lt
# lt.monkey_patch()
# Record the library versions used for this run.
for version_label, version_value in (
    ("Python Version: ", sys.version),
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
    ("WandB version: ", wandb.__version__),
):
    print(version_label, version_value)

# Seed the torch, Python `random` and NumPy generators with the same value.
seed = 801
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(seed)

## Run Config

In [None]:
# The presence of KAGGLE_URL_BASE in the environment is used as the marker
# for running inside a Kaggle notebook.
kaggle_run = 'KAGGLE_URL_BASE' in os.environ
print(f'Run on kaggle: {kaggle_run}')

In [None]:
if kaggle_run:
    # On Kaggle, read the W&B API key from the secrets store and expose it
    # through the WANDB_API_KEY environment variable.
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = secret_value_0
# wandb.login()
# Toggles wandb between 'online' and 'disabled' mode in wandb.init below.
use_wandb = True

In [None]:
# compression_config = {
#     'layer1':{
#         'forward': 'id',
#         'forward-EF': False,
#         'forward-params': {},
#         'backward': 'id',
#         'backward-EF': False,
#         'backward-params': {},
#     },
#     'layer2':{
#         'forward': 'id',
#         'forward-EF': False,
#         'forward-params': {},
#         'backward': 'id',
#         'backward-EF': False,
#         'backward-params': {},
#     },
#     'layer3':{
#         'forward': 'id',
#         'forward-EF': False,
#         'forward-params': {},
#         'backward': 'id',
#         'backward-EF': False,
#         'backward-params': {},
#     },
# }

In [None]:
# Dict-based config for compression.
# Every compression point uses the same setup: forward keeps the top 5% of
# activations, backward keeps the top 10% of gradients, no error feedback.
# The comprehension builds fresh inner dicts for each layer, so entries can be
# edited independently.
compression_config = {
    layer_name: {
        'forward': 'topk',
        'forward-EF': False,
        'forward-params': {'topk': 0.05},
        'backward': 'topk',
        'backward-EF': False,
        'backward-params': {'topk': 0.1},
    }
    for layer_name in ('layer1', 'layer2', 'layer3')
}

In [None]:
# Run hyperparameters; the dict is also passed to wandb.init(config=...).
config = {
    "learning_rate": 0.01,
    "architecture": "resnet-18",
    "dataset": "CIFAR10",
    "epochs": 100,
    "compression-config": compression_config,
    # Also decides whether the datasets return sample indices (with_idx).
    "AC-SGD": False,
    "seed": seed,
}

In [None]:
# The run name encodes the compression setup: forward top-5%, backward top-10%.
run = wandb.init(project="resnet-cifar-compression", name='fw-top5-bw-top10',
                 mode='online' if use_wandb else 'disabled', config=config)

Training implementation is based on:
https://github.com/kuangliu/pytorch-cifar

## Dataset

In [None]:
from torch.utils.data import Dataset

class CIFAR10(Dataset):
    """Wrapper over torchvision's CIFAR-10 that can also return each sample's index.

    Args:
        train: use the training split with random-crop + horizontal-flip
            augmentation when True, otherwise the test split without augmentation.
        with_idx: when True, items are ``(data, target, index)`` instead of
            ``(data, target)``.
    """

    def __init__(self, train=True, with_idx=False):
        # Both splits end with the same tensor conversion and per-channel normalization.
        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        if train:
            pipeline = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ] + pipeline

        self.with_idx = with_idx
        self.cifar10 = torchvision.datasets.CIFAR10(
            root='./data',
            train=train,
            download=True,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        sample, label = self.cifar10[index]
        return (sample, label, index) if self.with_idx else (sample, label)

    def __len__(self):
        return len(self.cifar10)

In [None]:
# With AC-SGD enabled the datasets also return each sample's index.
with_idx = config["AC-SGD"]
batch_size = 128

trainloader = torch.utils.data.DataLoader(
    CIFAR10(with_idx=with_idx),
    batch_size=batch_size, shuffle=True, num_workers=2,
    # Full batches only: the error-feedback buffers of the compression points
    # are allocated with a fixed batch_size.
    drop_last=True
)

# Keep the final partial batch so every test image is evaluated. test() runs
# under torch.no_grad(), so the batch_size-shaped backward buffers are never
# touched during evaluation.
testloader = torch.utils.data.DataLoader(
    CIFAR10(train=False, with_idx=with_idx),
    batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False
)

# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Prefer the first CUDA device when available, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# torch.cuda.get_device_name only accepts CUDA devices; it raises for 'cpu'.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(device))
else:
    print('CUDA not available, running on CPU')

## Compression Functions

### TopK

In [None]:
# Function to compress input vector
# Returns vector of the same shape, zero-out lowest(absolute value) (1-k)*100% values
def compress_topk(input, topk=0.1):
    input_2d = input.flatten(start_dim=1)
    n_lowest = int(input_2d.shape[1] * (1 - topk))
    input_2d_abs = torch.abs(input_2d)
    pivot = torch.kthvalue(input_2d_abs,k=n_lowest, keepdim=True).values
    mask = input_2d_abs <= pivot
    out = input_2d.masked_fill_(mask, 0)
    return out.reshape(input.shape)

### Custom Quantization

In [None]:
def quantize_custom(input, k=2):
    """Uniformly quantize ``input`` to ``2**k`` levels over its global value range.

    The range ``[min, max]`` is taken over the whole tensor, not per sample.
    Values are rounded to the nearest of ``2**k`` evenly spaced levels and
    mapped back to the original scale. The result is a simulated quantization
    with the same shape and dtype as ``input``.

    Args:
        input: floating-point tensor of any shape.
        k: number of bits; gives ``2**k`` quantization levels.

    Returns:
        The quantized tensor. A constant tensor (zero range) is returned
        unchanged instead of producing NaNs.
    """
    x_min = input.min()
    x_max = input.max()
    x_range = x_max - x_min
    # Divide by 1 when the range is 0, so a constant tensor maps back to itself instead of to NaN.
    safe_range = torch.where(x_range > 0, x_range, torch.ones_like(x_range))
    x_norm = (input - x_min) / safe_range  # from 0 to 1
    val = 1/(2**k-1) * torch.round((2**k - 1) * x_norm)  # quantized
    return val * x_range + x_min  # back to scale

### AC-SGD Quantization

Reference (`compress/fixpoint.py` in AC-SGD): https://github.com/DS3Lab/AC-SGD/blob/6c1d90093f700ed6716dda3e77733160f3794148/compress/fixpoint.py#L22

In [None]:
def quantize_ACSGD(x, nbits, scale_method='max', scale_dims=(0,)):
    """Quantize ``x`` to unsigned ``nbits``-bit fixed-point codes with a per-slice scale.

    ``x`` is divided by a scale reduced over ``scale_dims`` (kept as size-1
    dims) and multiplied by ``2**(nbits-1)``. It is then rounded, clipped to
    ``[-2**(nbits-1), 2**(nbits-1) - 1]`` and shifted to be non-negative so
    the codes fit in ``uint8``.

    Args:
        x: floating-point tensor to quantize.
        nbits: bits per code; must be in ``[1, 8]`` so the codes fit ``uint8``.
        scale_method: ``'max'`` (max absolute value) or ``'l2'`` (2x RMS).
        scale_dims: dimensions reduced when computing the scale.

    Returns:
        ``(codes, scale)``: ``uint8`` codes with ``x``'s shape and the fp16 scale.

    Raises:
        ValueError: for an unknown ``scale_method`` or ``nbits`` outside ``[1, 8]``.
    """
    if not 1 <= nbits <= 8:
        raise ValueError(f'nbits must be in [1, 8] for uint8 codes, got {nbits}')
    fbits = nbits - 1

    if scale_method == 'max':
        # issue: sensitive to outlier points
        scale = x.abs().amax(scale_dims, keepdim=True)
    elif scale_method == 'l2':
        # ~95% confidence interval for normal distribution
        scale = x.pow(2).mean(scale_dims, keepdim=True).sqrt() * 2
    else:
        raise ValueError(f'unknown scale method: {scale_method!r}')
    # fp16 should be enough (note: scales above the fp16 max overflow to inf).
    scale = scale.half()
    x = x / (scale + 1e-6)

    # Multiply by 2**fbits with a plain Python int. For a power of two this
    # matches ldexp, and no extra CPU-side tensor is created.
    x = x * (1 << fbits)
    clip_min = -(1 << fbits)
    clip_max = (1 << fbits) - 1

    x = x.round().clip(clip_min, clip_max)
    # Shift into [0, 2**nbits - 1] so the codes are non-negative.
    codes = (x - clip_min).to(torch.uint8)
    return codes, scale


def dequantize_ACSGD(x, nbits, scale):
    """Map ``uint8`` codes from ``quantize_ACSGD`` back to float values.

    Args:
        x: ``uint8`` codes.
        nbits: the bit width used when quantizing.
        scale: the scale returned by ``quantize_ACSGD`` (broadcast against ``x``).

    Returns:
        A float32 tensor with ``x``'s shape.
    """
    fbits = nbits - 1
    clip_min = -(1 << fbits)
    clip_max = (1 << fbits) - 1

    # Undo the non-negative shift, then the 2**fbits fixed-point scaling.
    x = x.float() + clip_min
    return x / (clip_max + 1) * scale


def quantize_dequantize_ACSGD(x, nbits, scale_method='max', scale_dims=(0,)):
    """Simulate AC-SGD compression: quantize ``x`` and immediately dequantize it.

    ``x`` is viewed as ``(x.shape[0], -1)`` for scaling. With the default
    ``scale_dims=(0,)`` each flattened feature gets its own scale, shared
    across dim 0.

    Returns:
        A float32 tensor with the same shape as ``x``.
    """
    flat = x.flatten(start_dim=1)
    codes, scale = quantize_ACSGD(flat, nbits, scale_method, scale_dims)
    # Restore the caller's shape; downstream layers expect e.g. 4-D activations.
    return dequantize_ACSGD(codes, nbits, scale).reshape(x.shape)

## Torch Autograd Compressors

In [None]:
def BaseCompressor(compressor_forward, compressor_backward):
    """Build an autograd function that compresses activations and their gradients.

    Args:
        compressor_forward: callable applied to a detached copy of the input
            on the forward pass, or ``None`` for identity.
        compressor_backward: callable applied to a detached copy of the
            incoming gradient on the backward pass, or ``None`` for identity.

    Returns:
        ``myCompressor.apply``, called as ``f(input, compress)``. When
        ``compress`` is False both passes are the identity.
    """

    class myCompressor(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input, compress=True):
            # Always record the flag first: backward reads it on every path.
            ctx.compress = compress
            x = input.detach().clone()
            if compress and compressor_forward is not None:
                return compressor_forward(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            g = grad_output.detach().clone()
            if ctx.compress and compressor_backward is not None:
                g = compressor_backward(g)
            # One gradient per forward input: the tensor, and None for the `compress` flag.
            return g, None

    return myCompressor.apply

In [None]:
def EFCompressor(compressor_forward, fwd_error_buffer, compressor_backward, bckwd_error_buffer):
    """Build an autograd compressor with error feedback (EF) on the backward pass.

    Backward compresses ``grad + e``, where ``e`` is the residual the previous
    backward compression dropped, and stores the new residual in
    ``bckwd_error_buffer`` in place.

    Args:
        compressor_forward: callable applied to activations, or ``None`` for identity.
        fwd_error_buffer: accepted for interface symmetry but currently unused;
            forward error feedback is disabled.
        compressor_backward: callable applied to ``grad + e``, or ``None`` for identity.
        bckwd_error_buffer: tensor updated in place with the backward residual;
            must broadcast with the gradient (same shape in practice).

    Returns:
        ``EFCompression.apply``, called as ``f(input, compress)``.
    """

    class EFCompression(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input, compress=True):
            # Always record the flag first: backward reads it on every path.
            ctx.compress = compress
            x = input.detach().clone()
            if compress and compressor_forward is not None:
                # Plain compression: fwd_error_buffer is neither read nor updated.
                return compressor_forward(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            g = grad_output.detach().clone()
            if not ctx.compress or compressor_backward is None:
                return g, None
            e = bckwd_error_buffer  # residual left over from the previous step
            c_g = compressor_backward(g + e)
            # New residual = what this step failed to transmit; written in place so it persists.
            bckwd_error_buffer[:] = e + g - c_g
            # One gradient per forward input; the `compress` flag gets None.
            return c_g, None

    return EFCompression.apply

In [None]:
def EF21Compressor(compressor_forward, fwd_error_buffer, compressor_backward, bckwd_error_buffer):
    """Build an autograd compressor using EF21-style difference compression.

    Each direction keeps a running estimate ``A`` in its buffer. A new value
    ``P`` is transmitted as ``A_new = A + C(P - A)``; the buffer is updated to
    ``A_new`` in place and ``A_new`` is returned.

    Args:
        compressor_forward: callable for activation differences, or ``None`` for identity.
        fwd_error_buffer: running activation estimate, updated in place.
        compressor_backward: callable for gradient differences, or ``None`` for identity.
        bckwd_error_buffer: running gradient estimate, updated in place.

    Returns:
        ``EF21Compression.apply``, called as ``f(input, compress)``.
    """

    class EF21Compression(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input, compress=True):
            # Always record the flag first: backward reads it on every path.
            ctx.compress = compress
            x = input.detach().clone()
            if not compress or compressor_forward is None:
                return x
            prev = fwd_error_buffer  # A0: previous activation estimate
            estimate = prev + compressor_forward(x - prev)  # A1 = A0 + C(P1 - A0)
            fwd_error_buffer[:] = estimate
            # Return the separate tensor, not the buffer, so later buffer writes don't alias the output.
            return estimate

        @staticmethod
        def backward(ctx, grad_output):
            g = grad_output.detach().clone()
            if not ctx.compress or compressor_backward is None:
                return g, None
            prev = bckwd_error_buffer  # previous gradient estimate
            estimate = prev + compressor_backward(g - prev)
            bckwd_error_buffer[:] = estimate
            # One gradient per forward input; the `compress` flag gets None.
            return estimate, None

    return EF21Compression.apply

## Compressor Module

In [None]:
class Compressor(nn.Module):
    """Compression point: compresses activations forward and gradients backward.

    Args:
        input_shape: shape used to allocate the error-feedback buffers (EF only).
        forward: name of the forward compressor (a key of ``compressors_dict``).
        forward_params: keyword arguments for the forward compressor.
        backward: name of the backward compressor.
        backward_params: keyword arguments for the backward compressor.
        EF: when True, use EFCompressor (error feedback on gradients) instead
            of BaseCompressor.

    Raises:
        ValueError: if a compressor name is not in ``compressors_dict``.
    """

    # Registry of compressor names usable in configs; 'id' means no compression.
    compressors_dict = {
        'id': None,
        'topk': compress_topk,
        'quantization_simple': quantize_custom,
        'quantization_acsgd': quantize_dequantize_ACSGD
    }

    def _get_compressor(self, compressor_id, **compressor_params):
        """Return a one-argument callable applying compressor ``compressor_id``."""
        if compressor_id not in self.compressors_dict:
            raise ValueError(f'Compressor with name {compressor_id} Not found')
        compress_fn = self.compressors_dict[compressor_id]
        if compress_fn is None:
            return lambda x: x
        return lambda x: compress_fn(x, **compressor_params)

    def __init__(self, input_shape, forward='id', forward_params=None, backward='id', backward_params=None, EF=False):
        super(Compressor, self).__init__()
        # None defaults avoid sharing one mutable dict between all instances.
        forward_params = {} if forward_params is None else forward_params
        backward_params = {} if backward_params is None else backward_params
        self.input_shape = input_shape
        self.forward_func = forward
        self.forward_params = forward_params
        self.backward_func = backward
        self.backward_params = backward_params
        forward_compression = self._get_compressor(forward, **forward_params)
        backward_compression = self._get_compressor(backward, **backward_params)
        if EF:
            # Plain tensors (not registered buffers) on the notebook-level `device`;
            # EFCompressor keeps references to these exact objects.
            self.EF_forward_buffer = torch.zeros(input_shape, device=device)
            self.EF_backward_buffer = torch.zeros(input_shape, device=device)
            self.compressor = EFCompressor(
                forward_compression, self.EF_forward_buffer,
                backward_compression, self.EF_backward_buffer
            )
        else:
            self.compressor = BaseCompressor(
                forward_compression, backward_compression
            )

    def forward(self, x, compress=True):
        """Apply the compressor; ``compress=False`` makes both passes the identity."""
        return self.compressor(x, compress)

    def extra_repr(self):
        # Shown inside the module's repr, e.g. when printing the whole network.
        return (f'forward={self.forward_func}, forward_params={self.forward_params}, '
                f'backward={self.backward_func}, backward_params={self.backward_params}')

## Model

ResNet-18 adapted from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar)

In [None]:
class BasicBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BN layers and a shortcut connection.

    Args:
        in_planes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution and of the projection shortcut.
    """
    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A 1x1 projection is only needed when the stride or channel count changes.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, then 1x1 expand by ``expansion``.

    Args:
        in_planes: channels of the incoming tensor.
        planes: width of the inner convolutions; the output has ``4 * planes`` channels.
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """
    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        # Identity shortcut when shapes already match, otherwise a 1x1 projection.
        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet for 32x32 inputs: 3x3 stem, four residual stages, avg-pool and linear head.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # (width, stride of the first block) for layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (planes, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[stage_idx], stride=first_stride)
            setattr(self, 'layer%d' % (stage_idx + 1), stage)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``.

        At least one block is always built, even when ``num_blocks`` < 1.
        """
        strides = [stride, *([1] * (num_blocks - 1))]
        blocks = []
        for block_stride in strides:
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        return self.linear(pooled.view(pooled.size(0), -1))


def ResNet18():
    """Build the uncompressed ResNet-18: BasicBlock with two blocks in every stage."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage)

## Model With Compression

In [None]:
class ResNetWithCompression(nn.Module):
    """ResNet whose layer1, layer2 and layer3 outputs pass through a Compressor.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        compression_config: dict keyed by 'layer1', 'layer2' and/or 'layer3'.
            Each value holds 'forward', 'forward-params', 'backward',
            'backward-params' and optionally 'backward-EF'. Layers missing
            from the dict (or a ``None`` config) get identity compressors.

    Raises:
        ValueError: if the config names a layer other than layer1-layer3.
    """

    def __init__(self, block, num_blocks, num_classes=10, compression_config=None):
        super(ResNetWithCompression, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

        # Output shape of each compressed stage for 32x32 inputs (stride-2
        # stages halve H and W). The shape only sizes the error-feedback
        # buffers, which are tied to the notebook-level batch_size.
        shapes = {
            'layer1': (batch_size, 64 * block.expansion, 32, 32),
            'layer2': (batch_size, 128 * block.expansion, 16, 16),
            'layer3': (batch_size, 256 * block.expansion, 8, 8),
        }
        if compression_config is None:
            compression_config = {}
        unknown = set(compression_config) - set(shapes)
        if unknown:
            raise ValueError(
                f'Unknown compression points {sorted(unknown)}; expected a subset of {list(shapes)}')

        # Match config entries to shapes by name, not by dict order.
        for module_name, input_shape in shapes.items():
            compression_params = compression_config.get(module_name)
            if compression_params is None:
                compressor = Compressor(input_shape)  # identity in both directions
            else:
                compressor = Compressor(
                    input_shape,
                    compression_params['forward'], compression_params['forward-params'],
                    compression_params['backward'], compression_params['backward-params'],
                    # Compressor's EF option only adds error feedback to gradients,
                    # so 'forward-EF' is not consumed here.
                    EF=compression_params.get('backward-EF', False),
                )
            setattr(self, module_name + '_compression', compressor)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, compress=True):
        """Run the network; ``compress=False`` bypasses all three compressors."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1_compression(self.layer1(out), compress=compress)
        out = self.layer2_compression(self.layer2(out), compress=compress)
        out = self.layer3_compression(self.layer3(out), compress=compress)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18WithCompression(compression_config):
    """Build ResNet-18 (BasicBlock, two blocks per stage) with compressed stage outputs."""
    return ResNetWithCompression(
        block=BasicBlock,
        num_blocks=[2] * 4,
        compression_config=compression_config,
    )

## Training Parameters

In [None]:
# Build the compressed ResNet-18 from the dict config and move it to the selected device.
net = ResNet18WithCompression(compression_config=compression_config)
net.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config['learning_rate'],
                      momentum=0.9, weight_decay=5e-4)
# scheduler.step() is called once per epoch, so the cosine period must equal
# the number of epochs. A hard-coded T_max=200 with config['epochs'] == 100
# stopped the schedule halfway, at lr/2.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])

In [None]:
# Best test accuracies seen so far (uncompressed and compressed); updated by test().
best_acc = best_acc_compressed = 0

## Train & Test Loop

In [None]:
def train(epoch):
    """Run one training epoch of ``net`` over ``trainloader`` with compression on.

    Prints running loss/accuracy every 100 batches. Logs the epoch's mean
    per-batch loss and accuracy to wandb with ``commit=False``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, batch in enumerate(trainloader):
        # With with_idx=True the dataset also yields sample indices; they are not used here.
        inputs, targets = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs, compress=True)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 99:
            print('Train: Loss: %.3f | Acc: %.3f%% (%d/%d)'  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    acc = 100.*correct/total
    wandb.log({
        # Mean per-batch loss, consistent with the printed value; the raw sum
        # scaled with the number of batches.
        'train_loss': train_loss / len(trainloader),
        'train_acc': acc
    }, commit=False)

In [None]:
def test(epoch):
    """Evaluate ``net`` on ``testloader`` both without and with compression.

    Each batch is run twice under ``torch.no_grad()``: first with
    ``compress=False``, then with ``compress=True``. Prints mean per-batch loss
    and accuracy for both modes and logs them to wandb. Also updates the
    globals ``best_acc`` and ``best_acc_compressed``.
    """
    global best_acc, best_acc_compressed
    net.eval()
    test_loss = 0
    test_loss_compressed = 0
    correct = 0
    correct_compressed = 0
    total = 0
    with torch.no_grad():
        for batch in testloader:
            # With with_idx=True the dataset also yields sample indices; they are not used here.
            inputs, targets = batch[0].to(device), batch[1].to(device)

            # no compression test
            outputs = net(inputs, compress=False)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            # compressed test
            outputs = net(inputs, compress=True)
            loss = criterion(outputs, targets)
            test_loss_compressed += loss.item()
            _, predicted = outputs.max(1)
            correct_compressed += predicted.eq(targets).sum().item()

            total += targets.size(0)
    n_batches = len(testloader)
    mean_loss = test_loss / n_batches
    mean_loss_compressed = test_loss_compressed / n_batches
    acc = 100.*correct/total
    acc_compressed = 100.*correct_compressed/total
    print('Test Uncompressed: Loss: %.3f | Acc: %.3f%% (%d/%d)' % (mean_loss, acc, correct, total))
    print('Test Compressed: Loss: %.3f | Acc: %.3f%% (%d/%d)' % (mean_loss_compressed, acc_compressed, correct_compressed, total))

    # Log the same mean per-batch losses that are printed (previously the raw sums).
    wandb.log({
        'test_loss': mean_loss,
        'test_acc': acc,
        'test_loss_compressed': mean_loss_compressed,
        'test_acc_compressed': acc_compressed,
        'epoch': epoch
    })
    if acc > best_acc:
        best_acc = acc
        print(f'New best acc {acc}')
    if acc_compressed > best_acc_compressed:
        best_acc_compressed = acc_compressed
        print(f'New best acc_compressed {acc_compressed}')

## Run Model!

In [None]:
import time
for epoch in range(config['epochs']):
    epoch_start = time.time()
    train(epoch)
    test(epoch)
    # Advance the LR schedule once per epoch, after evaluation.
    scheduler.step()
    print('Time: ', time.time() - epoch_start)

In [None]:
# Store the best accuracies in the run summary, then close the run.
for summary_key, summary_value in (
    ('best_test_accuracy', best_acc),
    ('best_test_accuracy_compressed', best_acc_compressed),
):
    wandb.run.summary[summary_key] = summary_value
wandb.finish()