In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms

from os.path import exists
import time
import matplotlib.pyplot as plt
import numpy as np


Optimizer classes

In [2]:
class signSGD(optim.Optimizer):
  """Sign-based SGD: each parameter entry moves by lr against the sign of its gradient.

  Based on the original implementation of signSGD from
  https://github.com/jxbz/signSGD/blob/master/signSGD_zeros.ipynb

  Arguments:
      params (iterable): Parameters (or parameter groups) to optimize.
      lr (float): Step size multiplied with the sign of every gradient entry.
      rand_zero (bool): If True, entries whose gradient sign is 0 are replaced
          by a random choice of -1 or +1 before the update.
  """

  def __init__(self, params, lr=0.01, rand_zero=True):
    self.rand_zero = rand_zero
    super(signSGD, self).__init__(params, dict(lr=lr))

  def step(self, closure=None):
    """Applies one sign-gradient update to every parameter that has a gradient.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    Returns:
        The value returned by closure, or None when no closure was given.
    """
    loss = closure() if closure is not None else None

    for group in self.param_groups:
      step_size = group['lr']
      for param in group['params']:
        if param.grad is None:
          continue

        # torch.sign allocates a new tensor, so param.grad itself stays untouched
        direction = torch.sign(param.grad)

        if self.rand_zero:
          # draw 0/1, then map it to -1/+1 for every entry with zero sign
          zero_mask = direction == 0
          direction[zero_mask] = torch.randint_like(direction[zero_mask], low=0, high=2)*2 - 1
          assert not (direction==0).any()

        param.data -= step_size * direction

    return loss


class compressedSGD(optim.Optimizer):
  """Generalized gradient compression using binning.

  Before the update every gradient entry g is replaced by a signed integer level:
      g > 0:   ceil(2**(num_bits-1) * g / max_g)
      g < 0:  -ceil(2**(num_bits-1) * g / min_g)
      g == 0:  left at 0
  max_g / min_g are per-entry running maxima / minima of the gradients seen so far.
  On every step after the first they are multiplied by decay_max / decay_min before
  being compared with the new gradient. The parameter then moves by lr times the level.

  The running values are exposed as the dicts max_grad_vals / min_grad_vals, keyed by
  parameter (None until the parameter has received its first gradient).

  Arguments:
      params (iterable): Parameters (or parameter groups) to optimize.
      lr (float): Step size multiplied with the integer level.
      rand_zero (bool): Accepted for signature compatibility; not used by this optimizer.
      num_bits (int): Positive integer; 2**(num_bits-1) is the largest level magnitude.
      decay_max (float): Factor in (0, 1] applied to the running maximum on each step.
      decay_min (float): Factor in (0, 1] applied to the running minimum on each step.
  Raises:
      ValueError: If num_bits is not a positive int, or a decay factor lies outside (0, 1].
  """

  def __init__(self, params, lr=0.01, rand_zero=True, num_bits=2, decay_max=1.0, decay_min=1.0):

    # Check the type first so a non-numeric num_bits is reported as ValueError
    # instead of failing inside the comparison.
    if type(num_bits) != int or num_bits <= 0:
       raise ValueError("Expected num_bits to be a positive integer.")
    if decay_max > 1 or decay_max <= 0 or decay_min > 1 or decay_min <= 0:
      raise ValueError("Expected decay_max and decay_min to be a float in the interval (0,1].")

    defaults = dict(lr=lr)
    super(compressedSGD, self).__init__(params, defaults)
    self.decay_max = decay_max
    self.decay_min = decay_min
    self.max_grad_vals = {}
    self.min_grad_vals = {}
    self.num_bits=num_bits

    for group in self.param_groups:
      for p in group['params']:
        self.max_grad_vals[p] = None
        self.min_grad_vals[p] = None

  def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    Returns:
        The value returned by closure, or None when no closure was given.
    """
    loss = None
    if closure is not None:
      loss = closure()

    # largest level magnitude a gradient entry can be mapped to
    levels = 2**(self.num_bits-1)

    # parameter update
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        # set the max and min vals for gradients;
        # .get() also covers parameters registered later through add_param_group
        prev_max = self.max_grad_vals.get(p)
        if prev_max is None:
          self.max_grad_vals[p] = torch.clone(p.grad)
        else:
          self.max_grad_vals[p] = torch.max(p.grad, prev_max * self.decay_max)

        prev_min = self.min_grad_vals.get(p)
        if prev_min is None:
          self.min_grad_vals[p] = torch.clone(p.grad)
        else:
          self.min_grad_vals[p] = torch.min(p.grad, prev_min * self.decay_min)

        # discretize gradient based on max_grad_val, if gradient > 0
        # and based on min_grad_val, if gradient < 0.
        # Work on a copy so that p.grad keeps the true gradient after the step.
        grad = torch.clone(p.grad)
        positive = grad > 0
        negative = grad < 0
        # where grad > 0 the running max is >= grad > 0, so the ratio lies in (0, 1]
        grad[positive] = torch.ceil(levels * torch.div(grad[positive], self.max_grad_vals[p][positive]))
        # where grad < 0 the running min is <= grad < 0, so the ratio lies in (0, 1]
        grad[negative] = - torch.ceil(levels * torch.div(grad[negative], self.min_grad_vals[p][negative]))

        # make update
        p.data -= group['lr'] * grad

    return loss

Tests for compressedSGD

In [65]:
from torch.autograd import Variable

def test_max():
  """Checks the running maximum gradient and the parameter updates of compressedSGD.

  Uses lr=1, num_bits=2 and no decay. With loss = sum(A @ x - c) the gradient
  of A[i][j] is x[j], so each step feeds a known, all-positive gradient.
  """
  # plain tensors replace the deprecated torch.autograd.Variable wrapper
  A = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
  optimizer = compressedSGD([A], lr=1, num_bits=2, decay_max=1.0, decay_min=1.0)

  x = torch.tensor([[1.0],[1.0]])
  y = A @ x
  loss = torch.sum(y-torch.tensor([[1.0],[1.0]]))
  loss.backward()
  optimizer.step()

  assert torch.equal(optimizer.max_grad_vals[A], torch.tensor([[1.0,1.0],[1.0,1.0]]) ), "Expected max grad tensor for A to be [[1.0,1.0],[1.0,1.0]]"
  assert torch.equal(A, torch.tensor([[-1.0,0.0],[1.0,2.0]]) ), "Expected A to be [[-1.0,0.0],[1.0,2.0]]"

  optimizer.zero_grad()
  x = torch.tensor([[2.0],[2.0]])
  y = A @ x
  loss = torch.sum(y-torch.tensor([[2.0],[2.0]]))
  loss.backward()
  optimizer.step()

  assert torch.equal(optimizer.max_grad_vals[A], torch.tensor([[2.0,2.0],[2.0,2.0]]) ), "Expected max grad tensor for A to be [[2.0,2.0],[2.0,2.0]]"
  assert torch.equal(A, torch.tensor([[-3.0,-2.0],[-1.0,0.0]]) ), "Expected A to be [[-3.0,-2.0],[-1.0,0.0]]"

  optimizer.zero_grad()
  x = torch.tensor([[1.0],[3.0]])
  y = A @ x
  loss = torch.sum(y-torch.tensor([[1.0],[3.0]]))
  loss.backward()
  optimizer.step()

  assert torch.equal(optimizer.max_grad_vals[A], torch.tensor([[2.0,3.0],[2.0,3.0]]) ), "Expected max grad tensor for A to be [[2.0,3.0],[2.0,3.0]]"
  assert torch.equal(A, torch.tensor([[-4.0,-4.0],[-2.0,-2.0]]) ), "Expected A to be [[-4.0,-4.0],[-2.0,-2.0]]"

def test_min():
  """Checks the running minimum gradient and the parameter updates of compressedSGD.

  Uses lr=1, num_bits=2 and no decay with three successive all-negative gradients.
  """
  A = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
  optimizer = compressedSGD([A], lr=1, num_bits=2, decay_max=1.0, decay_min=1.0)

  def descend(inputs, offsets):
    # loss = sum(A @ inputs - offsets); its gradient for A[i][j] is inputs[j]
    residual = A @ torch.tensor(inputs) - torch.tensor(offsets)
    torch.sum(residual).backward()
    optimizer.step()

  descend([[-1.0],[-1.0]], [[-1.0],[-1.0]])
  assert torch.equal(optimizer.min_grad_vals[A], torch.tensor([[-1.0,-1.0],[-1.0,-1.0]]) ), "Expected min grad tensor for A to be [[-1.0,-1.0],[-1.0,-1.0]]"
  assert torch.equal(A, torch.tensor([[3.0,4.0],[5.0,6.0]]) ), "Expected A to be [[3.0,4.0],[5.0,6.0]]"

  optimizer.zero_grad()
  descend([[-2.0],[-2.0]], [[2.0],[2.0]])
  assert torch.equal(optimizer.min_grad_vals[A], torch.tensor([[-2.0,-2.0],[-2.0,-2.0]]) )
  assert torch.equal(A, torch.tensor([[5.0,6.0],[7.0,8.0]]) )

  optimizer.zero_grad()
  descend([[-1.0],[-3.0]], [[1.0],[3.0]])
  assert torch.equal(optimizer.min_grad_vals[A], torch.tensor([[-2.0,-3.0],[-2.0,-3.0]]) )
  assert torch.equal(A, torch.tensor([[6.0,8.0],[8.0,10.0]]) )
  
def test_lr():
  """Checks that compressedSGD scales the quantized gradient by the learning rate.

  With num_bits=2 and a gradient of -1 everywhere the level is -2, so lr=0.1
  must raise every entry of A by 0.2.
  """
  A = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
  optimizer = compressedSGD([A], lr=0.1, num_bits=2, decay_max=1.0, decay_min=1.0)

  x = torch.tensor([[-1.0],[-1.0]])
  y = A @ x
  loss = torch.sum(y-torch.tensor([[-1.0],[-1.0]]))
  loss.backward()
  optimizer.step()
  # 0.2 and the expected values are not exactly representable in float32, so
  # compare with a tolerance instead of bit-for-bit equality
  assert torch.allclose(A.detach(), torch.tensor([[1.2,2.2],[3.2,4.2]]) )

def test_bins():
  """Checks the levels produced by compressedSGD with num_bits=3 (levels up to 4).

  Runs three steps with lr=1 and no decay and compares A after each step.
  """
  A = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
  optimizer = compressedSGD([A], lr=1, num_bits=3, decay_max=1.0, decay_min=1.0)

  # (input column, constant subtracted from the output, expected A after the step)
  schedule = [
    ([[-1.0],[-1.0]], [[-1.0],[-1.0]], [[5.0,6.0],[7.0,8.0]]),
    ([[-1.0],[-3.0]], [[1.0],[3.0]], [[9.0,10.0],[11.0,12.0]]),
    ([[-2.0],[-1.0]], [[1.0],[3.0]], [[13.0,12.0],[15.0,14.0]]),
  ]

  for step_no, (inputs, offsets, expected) in enumerate(schedule):
    # gradients are cleared between steps, but not before the first one
    if step_no > 0:
      optimizer.zero_grad()
    loss = torch.sum(A @ torch.tensor(inputs) - torch.tensor(offsets))
    loss.backward()
    optimizer.step()
    assert torch.equal(A, torch.tensor(expected) )

def test_decay():
  """Checks that the running maximum of compressedSGD decays by decay_max each step.

  With decay 0.5 the stored maximum is halved before it is compared with the
  new gradient, so a smaller gradient lets the maximum shrink.
  """
  A = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
  optimizer = compressedSGD([A], lr=1, num_bits=2, decay_max=0.5, decay_min=0.5)

  # (input column, constant subtracted from the output, expected running maximum)
  schedule = [
    ([[5.0],[5.0]], [[5.0],[5.0]], [[5.0,5.0],[5.0,5.0]]),
    ([[1.0],[1.0]], [[1.0],[3.0]], [[2.5,2.5],[2.5,2.5]]),
    ([[5.0],[1.0]], [[1.0],[3.0]], [[5.0,1.25],[5.0,1.25]]),
  ]

  for step_no, (inputs, offsets, expected_max) in enumerate(schedule):
    # gradients are cleared between steps, but not before the first one
    if step_no > 0:
      optimizer.zero_grad()
    loss = torch.sum(A @ torch.tensor(inputs) - torch.tensor(offsets))
    loss.backward()
    optimizer.step()
    assert torch.equal(optimizer.max_grad_vals[A], torch.tensor(expected_max) )
  
def test_compressedSGD():
  """Runs all compressedSGD checks in order; the first failing assertion propagates."""
  for check in (test_max, test_min, test_lr, test_bins, test_decay):
    check()

# execute the checks as soon as the cell runs
test_compressedSGD()

Models to be trained

In [3]:
class CNN_MNIST(nn.Module):
  """Small CNN classifier: one 3x3 convolution, 2x2 max-pooling, two linear layers.

  The convolution keeps the spatial size (padding 1) and the pooling halves it,
  so the 14*14*32 input width of fc1 corresponds to single-channel 28x28 images.

  Args:
    num_classes (int): Number of output logits produced by the last layer.
  """

  def __init__(self, num_classes):
    super().__init__()
    # 1 input channel -> 32 feature maps; zero padding of 1 keeps the original size
    self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3,
                          stride=1, padding=1, padding_mode='zeros')

    # 2x2 pooling window moved in steps of 2
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    self.dropout1 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(in_features=14*14*32, out_features=512)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

  def forward(self, x):
    """Maps a batch of images to unnormalized class scores (logits)."""
    pooled = self.maxpool(F.relu(self.conv(x)))
    # flatten everything except the batch dimension
    flat = self.dropout1(torch.flatten(pooled, 1))
    hidden = self.dropout2(F.relu(self.fc1(flat)))
    return self.fc2(hidden)


Functions for training the model.

In [4]:
def train(model, device, data_loader, optimizer, epoch, track_losses = False,
              log_interval=100):
  """Runs one pass over data_loader, updating model with optimizer on every batch.

  Args:
    model: Module to train; it is switched to training mode.
    device: Device each batch is moved to before the forward pass.
    data_loader: Iterable yielding (data, target) batches.
    optimizer: Optimizer whose zero_grad() and step() are called per batch.
    epoch: Epoch number, used only in the progress message.
    track_losses (bool): If True, the loss of every batch is collected.
    log_interval (int): A progress line is printed every log_interval batches.
  Returns:
    List with one cross entropy loss value per batch if track_losses is True,
    otherwise an empty list.
  """
  model.train()
  recorded = []

  for step_no, (inputs, labels) in enumerate(data_loader):
    inputs = inputs.to(device)
    labels = labels.to(device)

    # gradients are cleared before every batch
    optimizer.zero_grad()
    batch_loss = F.cross_entropy(model(inputs), labels)
    batch_loss.backward()
    optimizer.step()

    if step_no % log_interval == 0:
      seen = step_no * len(inputs)
      total = len(data_loader.dataset)
      percent = 100. * step_no / len(data_loader)
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, seen, total, percent, batch_loss.item()))

    if track_losses:
      recorded.append(batch_loss.item())

  return recorded


def test(model, device, data_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in data_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      # cross entropy loss, sum up batch loss
      test_loss += F.cross_entropy(output, target, reduction='sum').item()
      pred = output.argmax(dim=1, keepdim=True)  # get the index of maximum prediction
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(data_loader.dataset)

  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(data_loader.dataset),
      100. * correct / len(data_loader.dataset)))
  return (test_loss,  100. * correct / len(data_loader.dataset))

Main function:

In [5]:
def train_model_MNIST(
         num_classes,
         model_id,
         optimizer_choice = "compressedSGD",
         MODELPATH = './',
         save_interval = 25,
         epochs = 10,
         train_batch_size = 64,
         test_batch_size = 64,
         learning_rate = 0.01,
         use_scheduler = True,
         gamma = 0.7,
         step_size = 10,
         num_bits = 2,
         beta = 1,
         seed = 0,
         track_training_loss = False):
  """
  Trains a CNN_MNIST model on MNIST and evaluates it on the test split after every epoch.
  If a checkpoint exists at MODELPATH + "_model_" + model_id + ".pt", the model weights, the epoch
  counter and the evaluation history are restored from it; otherwise a fresh model is trained.
  After every epoch one line (epoch, test loss, test accuracy, timings) is appended to 'report.txt'.
  Args:
    num_classes (int): Output dimension of the CNN, i.e. number of different classes to predict.
    model_id (string): ModelID, is used to determine the outputfile to save the model.
    optimizer_choice (string): Choice of optimizer to be used. Has to be 'SGD', 'signSGD' or 'compressedSGD'.
    MODELPATH (string): Prefix of the checkpoint file name (concatenated as is, no separator is inserted).
    save_interval (int): The checkpoint is written after every epoch whose number is a multiple of this value.
    epochs (int): Number of epochs to train, after these epochs the function terminates.
    train_batch_size (int): Batchsize when training the model.
    test_batch_size (int): Batchsize when evaluating the model.
    learning_rate (float): Learning rate used in the optimizer.
    use_scheduler (bool): Boolean indicating if a StepLR scheduler should decay the learning rate.
    gamma (float): Decay factor if scheduler is used.
    step_size (int): Number of epochs after which the scheduler multiplies the learning rate by gamma.
    num_bits (int): Number of bits for the communication in compressedSGD. Only used by compressedSGD.
    beta (float): Decay factor for max gradient and min gradient in compressedSGD. Only used by compressedSGD.
    seed (int): Random seed to be set before training.
    track_training_loss (bool): Boolean indicating if training loss should be tracked and returned.
  Out:
    Returns the per-batch training losses of all epochs run by this call, if tracked, otherwise an empty list.
    Saves the model at MODELPATH + "_model_" + model_id + ".pt" on epochs that are a multiple of save_interval.
  Raises:
    ValueError: If optimizer_choice is not one of the supported names.
  """
  if optimizer_choice not in {"SGD", "signSGD", "compressedSGD"}:
       raise ValueError("Must choose an optimizer out of 'SGD', 'signSGD' or 'compressedSGD'.")

  PATH = MODELPATH + "_model_" + model_id + ".pt"

  torch.manual_seed(seed)

  train_kwargs = {'batch_size': train_batch_size,
                  'shuffle': True}
  test_kwargs = {'batch_size': test_batch_size}

  # load the training and test data (the only preprocessing is the conversion to tensors)
  training_data = torchvision.datasets.MNIST(
    root = 'data',
    train = True,
    transform = torchvision.transforms.ToTensor(),
    download = True)

  test_data = torchvision.datasets.MNIST(
    root = 'data',
    train = False,
    transform =  torchvision.transforms.ToTensor())

  # identify GPU usage
  use_cuda = torch.cuda.is_available()
  print("Cuda Available: " + str(use_cuda))
  if use_cuda:
    print("GPU Count: " + str(torch.cuda.device_count()))

  device = torch.device("cuda" if use_cuda else "cpu")
  if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                    'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

  train_loader = torch.utils.data.DataLoader(training_data,**train_kwargs)
  test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

  model = CNN_MNIST(num_classes)

  # load parameters, if file exists otherwise keep the freshly initialized weights
  epoch_num = 0
  validation_errors = []
  if exists(PATH):
    # map_location lets a checkpoint written on another device be restored here
    checkpoint = torch.load(PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch_num = checkpoint['epoch']
    validation_errors = checkpoint['loss']

  # The model must live on the training device in both cases; a restored model
  # left on the CPU would not match batches that were moved to the GPU.
  model.to(device)

  # initialize the optimizer (created after the move so it references the final parameters)
  optimizer = None
  if optimizer_choice == "SGD":
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
  elif optimizer_choice == "signSGD":
    optimizer = signSGD(model.parameters(), lr=learning_rate)
  elif optimizer_choice == "compressedSGD":
    optimizer = compressedSGD(model.parameters(), lr=learning_rate, num_bits=num_bits, decay_max=beta, decay_min=beta)

  # Decays the learning rate of each parameter group by gamma every step_size epochs.
  scheduler = None
  if use_scheduler:
     scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

  # for statistics
  t0 = time.time()
  tlast = t0

  # Collected across all epochs; defined before the loop so that epochs = 0
  # returns an empty list instead of failing on an undefined name.
  loss = []

  for epoch in range(epoch_num + 1, epoch_num + epochs + 1):
    if track_training_loss:
      loss += train(model, device, train_loader, optimizer, epoch, track_losses = True)
    else:
      train(model, device, train_loader, optimizer, epoch, track_losses = False)
    verror = test(model, device, test_loader)
    validation_errors.append(verror)

    td = time.time() - tlast
    tds = time.time() - t0
    tds_m = int(tds // 60)
    tds_s = int(tds) % 60
    tlast = time.time()

    with open("report.txt", "a") as myfile:
      # epoch, loss, accuracy, seconds for this epoch, time since start
      myfile.write('{},{:.4f},{:.4f}%,{:8.2f}s,{:8d}m{:d}s\n'.format(epoch, verror[0], verror[1], td, tds_m, tds_s))
    print("{:8.2f}s since last epoch\n{:8d}m{:d}s since start".format((td), tds_m, tds_s))

    if epoch % save_interval == 0:
      # NOTE(review): optimizer and scheduler state are not part of the checkpoint,
      # so a resumed run starts both from scratch.
      torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': validation_errors,
            }, PATH)
    if scheduler is not None:
      # Decays the learning rate of each parameter group by gamma every step_size epochs.
      scheduler.step()

  return loss




Function to visualize training loss of compressedSGD:

In [None]:
def plot_training_loss_cSGD(learning_rates, betas, num_bits, save_fig = False, file='compressedSGD_MNIST.pdf'):
  """Trains one epoch of compressedSGD per (learning rate, beta) pair and plots the training losses.

  The figure has one row per learning rate and one column per beta; every panel
  shows the per-batch training loss returned by train_model_MNIST.

  Args:
    learning_rates (sequence of float): Learning rates to try (rows of the figure).
    betas (sequence of float): Decay factors to try (columns of the figure).
    num_bits (int): Number of bits forwarded to compressedSGD.
    save_fig (bool): If True, the figure is written to file.
    file (string): Output path used when save_fig is True.
  Raises:
    ValueError: If learning_rates or betas is empty.
  """
  if len(learning_rates) == 0 or len(betas) == 0:
    raise ValueError("Expected at least one learning rate and one beta.")

  loss = {}
  for lr in learning_rates:
    for b in betas:
      loss[(lr, b)] = train_model_MNIST(
      num_classes = 10,
      model_id = "1",
      optimizer_choice = "compressedSGD",
      epochs = 1,
      train_batch_size = 64,
      test_batch_size = 64,
      learning_rate = lr,
      use_scheduler = False,
      gamma = 0.7,
      step_size = 1,
      num_bits = num_bits,
      beta = b,
      seed = 1337,
      track_training_loss = True)

  # Use the betas argument, not a notebook-level variable, so the function
  # works for any grid passed in. The x axis is taken from the first run.
  first_run = loss[(learning_rates[0], betas[0])]
  x = np.linspace(0, len(first_run), num = len(first_run))

  # squeeze=False keeps axs two-dimensional even for a single row or column
  fig, axs = plt.subplots(len(learning_rates), len(betas), figsize=(20, 15), squeeze=False)

  for i, lr in enumerate(learning_rates):
    for j, b in enumerate(betas):
      axs[i, j].plot(x, loss[(lr, b)])
      axs[i, j].set_ylim([0, 4])
      axs[i, j].title.set_text('learning rate: ' + str(lr) + ', beta: ' +  str(b))

  plt.tight_layout()
  if save_fig:
    plt.savefig(file)

# Grid for the compressedSGD sweep: one training run per (learning rate, beta) pair.
learning_rates = [0.01, 0.001, 0.0005, 0.0001, 0.00005]
# These values are passed as `betas` (the decay factors of compressedSGD);
# despite the name they are not an optimizer weight-decay setting.
weight_decay = [1, 0.8, 0.5, 0.3, 0.1]   
# 3 is the num_bits argument forwarded to compressedSGD.
plot_training_loss_cSGD(learning_rates, weight_decay, 3, save_fig=True)

Function to visualize training loss of SGD/signSGD

In [None]:
def plot_training_loss_SGD(learning_rates, optimizer_choice = "SGD", save_fig = False, file='_MNIST.pdf'):
  """Trains one epoch per learning rate with the chosen optimizer and plots the training losses.

  The figure has one panel per learning rate showing the per-batch training
  loss returned by train_model_MNIST.

  Args:
    learning_rates (sequence of float): Learning rates to try (one panel each).
    optimizer_choice (string): Optimizer name forwarded to train_model_MNIST.
    save_fig (bool): If True, the figure is written to optimizer_choice + file.
    file (string): Suffix of the output path used when save_fig is True.
  Raises:
    ValueError: If learning_rates is empty.
  """
  if len(learning_rates) == 0:
    raise ValueError("Expected at least one learning rate.")

  loss = {}
  for lr in learning_rates:
    loss[lr] = train_model_MNIST(
    num_classes = 10,
    model_id = "1",
    optimizer_choice = optimizer_choice,
    epochs = 1,
    train_batch_size = 64,
    test_batch_size = 64,
    learning_rate = lr,
    use_scheduler = False,
    gamma = 0.7,
    step_size = 1,
    seed = 1337,
    track_training_loss = True)

  # the x axis is taken from the first run
  first_run = loss[learning_rates[0]]
  x = np.linspace(0, len(first_run), num = len(first_run))

  # squeeze=False keeps axs indexable as an array even for a single learning rate
  fig, axs = plt.subplots(len(learning_rates), figsize=(5, 15), squeeze=False)

  for i, lr in enumerate(learning_rates):
    panel = axs[i, 0]
    panel.plot(x, loss[lr])
    panel.set_ylim([0, 4])
    panel.title.set_text('learning rate: ' + str(lr))

  plt.tight_layout()
  if save_fig:
    plt.savefig(optimizer_choice + file)

# Rebinds the notebook-level learning_rates to a wider grid for the plain SGD sweep
# (one single-epoch training run per value).
learning_rates = [0.5, 0.1, 0.01, 0.001, 0.0005, 0.0001, 0.00005]
plot_training_loss_SGD(learning_rates, optimizer_choice="SGD", save_fig=True)

In [None]:
# Single-epoch SGD run. As the last expression of the cell, the returned list of
# training losses is shown as the cell output.
# num_bits and beta are only forwarded to compressedSGD, and gamma / step_size are
# only used when use_scheduler is True, so none of them affects this run.
train_model_MNIST(
     num_classes = 10,
     model_id = "1",
     optimizer_choice = "SGD",
     save_interval = 25, 
     epochs = 1,
     train_batch_size = 64,
     test_batch_size = 64,
     learning_rate = 0.1,
     use_scheduler = False,
     gamma = 0.7,
     step_size = 1,
     num_bits = 2,
     beta = 0.9,
     seed = 1337,
     track_training_loss = True)