**TODO:** Insert better comments, and maybe a plot or two.

**TODO:** Move this file, together with the two simple MPI notebooks, into a common location.
- Put the two simple MPI notebooks there too, and possibly rename the three as `start0_install_mpi_notebook`, `start1_simple_mpi_notebook`, and `start2_mnist_mpi_notebook`.

**TODO:** Add a cell that creates and times the serial network, for comparison with the layer-parallel network.

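**TODO:** Then experiment with 4 or 6 MPI tasks and 16 or 24 layers, using 1 backward iteration with the down-cycle skipped and 2 forward iterations (`lp_iters = 1`, `lp_use_downcycle = False`, `lp_fwd_iters = 2`). Is there a speedup?
- Also try varying other settings, such as the number of channels.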

### Apply layer-parallel Torchbraid to simple MNIST problem (fashion or digits)
- See `start0_install_mpi_jupyter` and `start1_simple_mpi_notebook` for setting up an MPI-compatible Jupyter installation.

In [1]:
# Connect to the local ipyparallel cluster.  Note: the profile name passed to
# Client(profile=...) below must match the ipcluster profile the engines were started with.
# Here we use 'mpi', but you can name the cluster profile anything.
# Later cells prefixed with %%px run on the cluster's engines, not in this kernel.
# NOTE(review): `error` is imported but not used in the visible cells.
from ipyparallel import Client, error
cluster = Client(profile='mpi')



In [7]:
%%px
# Imports executed on every MPI engine (this cell runs under %%px).
# Torchbraid is located through sys.path: set the TORCHBRAID_DIR environment variable
# in the environment the engines were launched from (or edit the fallback below)
# so that it points at your Torchbraid checkout.
from __future__ import print_function

import os
import statistics as stats
import sys
from timeit import default_timer as timer

import matplotlib.pyplot as pyplot
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from mpi4py import MPI

# The fallback keeps the original hard-coded location so existing setups keep working.
TORCHBRAID_DIR = os.environ.get(
  "TORCHBRAID_DIR",
  "/Users/jacobschroder/joint_repos/torchbraid/torchbraid_py3_10/torchbraid")
# Guard so that re-running this cell does not append the same entry again.
if TORCHBRAID_DIR not in sys.path:
  sys.path.append(TORCHBRAID_DIR)
import torchbraid
import torchbraid.utils

In [4]:
# Download the data once, up front.  This cell has no %%px, so it runs only in this
# (client) kernel; the later %%px cells open the same './digit-data' and
# './fashion-data' roots with download=False.
# NOTE(review): this assumes the engines share this kernel's working directory -- confirm.
# Depending on parallel setting, may want to do in parallel with `%%px` command
# torchvision is imported again here because the earlier import cell ran under %%px,
# i.e. on the engines, not in this kernel.
from torchvision import datasets, transforms
datasets.MNIST('./digit-data', download=True)
datasets.FashionMNIST('./fashion-data', download=True)

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ./fashion-data
    Split: Train

In [5]:
%%px
# Open, Close, and Hidden (Step) Layer architectures

class OpenFlatLayer(nn.Module):
  """Opening layer: tile every input image `channels` times along dimension 1.

  Has no trainable parameters.  An input with C entries in dimension 1 comes out
  with C * channels entries there; all other dimensions are unchanged.
  """

  def __init__(self, channels):
    super().__init__()
    self.channels = channels

  def forward(self, x):
    # Tensor.repeat takes one count per dimension: tile dim 1, leave the rest alone.
    repeats = [1] * x.dim()
    repeats[1] = self.channels
    return x.repeat(*repeats)


class CloseLayer(nn.Module):
  """Closing classifier head: flatten -> Linear(channels*28*28, 32) -> ReLU -> Linear(32, 10).

  Returns log-probabilities (log_softmax over dimension 1).  The input is expected
  to hold `channels` feature maps of 28x28 (the MNIST image size).
  """

  def __init__(self, channels):
    super().__init__()
    in_features = channels * 28 * 28
    hidden_width = 32
    # fc1 is created before fc2 so seeded initialisation matches the original order.
    self.fc1 = nn.Linear(in_features, hidden_width)
    self.fc2 = nn.Linear(hidden_width, 10)

  def forward(self, x):
    hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
    logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=1)


class StepLayer(nn.Module):
  """Hidden (step) layer: two 3x3 convolutions (padding 1), each followed by ReLU.

  Channel count and spatial size are preserved, so these layers can be stacked.
  """

  def __init__(self, channels):
    super().__init__()
    kernel_size = 3

    def make_conv():
      return nn.Conv2d(channels, channels, kernel_size, padding=1)

    # conv1 is still built before conv2, so seeded initialisation is unchanged.
    self.conv1 = make_conv()
    self.conv2 = make_conv()

  def forward(self, x):
    y = F.relu(self.conv1(x))
    y = F.relu(self.conv2(y))
    return y

In [10]:
%%px
# Serial network class (only used for comparison to parallel network)
class SerialNet(nn.Module):
  """Open layer -> sequential stack of StepLayers -> close layer, all on one process.

  Any of `open_nn`, `serial_nn` and `close_nn` may be supplied ready-made (this is
  how ParallelNet.saveSerialNet uses it).  Otherwise a default is built from
  `channels`.  The default middle section holds `local_steps` StepLayers over
  [0, Tf]; it is produced by a one-process torchbraid.LayerParallel on MPI.COMM_SELF
  and flattened with buildSequentialOnRoot().
  """

  def __init__(self, channels=12, local_steps=8, Tf=1.0, serial_nn=None, open_nn=None, close_nn=None):
    super().__init__()

    # Assignment order (open, serial, close) fixes both sub-module registration
    # order and the order in which default layers consume the RNG.
    self.open_nn = OpenFlatLayer(channels) if open_nn is None else open_nn
    self.serial_nn = self._build_serial(channels, local_steps, Tf) if serial_nn is None else serial_nn
    self.close_nn = CloseLayer(channels) if close_nn is None else close_nn

  @staticmethod
  def _build_serial(channels, local_steps, Tf):
    """Build the default sequential StepLayer stack via a single-rank LayerParallel."""
    numprocs = 1
    builder = torchbraid.LayerParallel(MPI.COMM_SELF, lambda: StepLayer(channels), numprocs * local_steps, Tf,
                                       max_fwd_levels=1, max_bwd_levels=1, max_iters=1)
    builder.setPrintLevel(0, True)
    return builder.buildSequentialOnRoot()

  def forward(self, x):
    for stage in (self.open_nn, self.serial_nn, self.close_nn):
      x = stage(x)
    return x

In [11]:
%%px
# Parallel network class
class ParallelNet(nn.Module):
  """Layer-parallel network: OpenFlatLayer -> torchbraid.LayerParallel of StepLayers -> CloseLayer.

  Each MPI rank owns `local_steps` StepLayers, so the layer-parallel section has
  local_steps * MPI.COMM_WORLD.Get_size() steps over [0, Tf].  The open and close
  layers are built through the LayerParallel `comp_op()` helper so that they only
  exist on rank 0 (they are None on the other ranks).

  Args:
    channels: channels produced by OpenFlatLayer and used by every StepLayer.
    local_steps: number of StepLayers owned by each rank.
    Tf: final time of the layer-parallel section.
    max_levels: maximum number of levels, used for both forward and backward solves.
    max_iters: maximum layer-parallel iterations.
    fwd_max_iters: if > 0, overrides max_iters for the forward solve only.
    print_level: torchbraid print level (setPrintLevel(..., True)).
    braid_print_level: braid print level (setPrintLevel(..., False)).
    cfactor: coarsening factor.
    fine_fcf: FCF-relaxation (True) instead of F-relaxation (False) on the fine grid.
    skip_downcycle, fmg, relax_only_cg, user_mpi_buf, gpu_direct_commu: forwarded to
      the corresponding torchbraid settings.
  """

  def __init__(self, channels=12, local_steps=8, Tf=1.0, max_levels=1, max_iters=1, fwd_max_iters=0, print_level=0,
               braid_print_level=0, cfactor=4, fine_fcf=False, skip_downcycle=True, fmg=False, relax_only_cg=0,
               user_mpi_buf=False, gpu_direct_commu=False):
    super(ParallelNet, self).__init__()

    step_layer = lambda: StepLayer(channels)

    numprocs = MPI.COMM_WORLD.Get_size()

    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps * numprocs, Tf,
                                                max_fwd_levels=max_levels, max_bwd_levels=max_levels,
                                                max_iters=max_iters, user_mpi_buf=user_mpi_buf,
                                                gpu_direct_commu=gpu_direct_commu)
    if fwd_max_iters > 0:
      # Report the forward-iteration override once, not once per rank.
      if MPI.COMM_WORLD.Get_rank() == 0:
        print('fwd_max_iters', fwd_max_iters)
      self.parallel_nn.setFwdMaxIters(fwd_max_iters)
    self.parallel_nn.setPrintLevel(print_level, True)
    self.parallel_nn.setPrintLevel(braid_print_level, False)
    self.parallel_nn.setCFactor(cfactor)
    self.parallel_nn.setSkipDowncycle(skip_downcycle)
    self.parallel_nn.setBwdRelaxOnlyCG(relax_only_cg)
    self.parallel_nn.setFwdRelaxOnlyCG(relax_only_cg)

    if fmg:
      self.parallel_nn.setFMG()
    # One relaxation sweep (FCF) on every level by default, then override the fine
    # grid (level 0): 0 sweeps = F-relaxation only, 1 sweep = FCF-relaxation.
    self.parallel_nn.setNumRelax(1)
    self.parallel_nn.setNumRelax(1 if fine_fcf else 0, level=0)

    # this object ensures that only the LayerParallel code runs on ranks!=0
    compose = self.compose = self.parallel_nn.comp_op()

    # By passing the constructors through 'compose' (composition, e.g. OpenFlatLayer o channels),
    # these are None on ranks other than 0 (there are no parameters to train there).
    self.open_nn = compose(OpenFlatLayer, channels)
    self.close_nn = compose(CloseLayer, channels)

  def saveSerialNet(self, name):
    """Gather the network into a SerialNet and torch.save it to `name` on rank 0.

    buildSequentialOnRoot() is called on every rank (presumably it needs all ranks
    to participate -- TODO confirm); only rank 0 builds and saves the SerialNet.
    """
    serial_nn = self.parallel_nn.buildSequentialOnRoot()
    if MPI.COMM_WORLD.Get_rank() == 0:
      s_net = SerialNet(-1, -1, -1, serial_nn=serial_nn, open_nn=self.open_nn, close_nn=self.close_nn)
      s_net.eval()
      torch.save(s_net, name)

  def getDiagnostics(self):
    """Return the diagnostics collected by the underlying LayerParallel module."""
    return self.parallel_nn.getDiagnostics()

  def forward(self, x):
    """Open layer (rank 0 only) -> layer-parallel section (all ranks) -> close layer (rank 0 only)."""
    # Passing the open/close layers through 'compose' (e.g. self.open_nn o x)
    # makes sure they are only run on processor 0.
    x = self.compose(self.open_nn, x)
    x = self.parallel_nn(x)
    x = self.compose(self.close_nn, x)

    return x

In [47]:
%%px

# Training function for one epoch
def train(rank, params, model, train_loader, optimizer, epoch, compose, device):
  """Train `model` for one epoch over `train_loader`, logging progress on rank 0.

  Args:
    rank: MPI rank of this process; only rank 0 prints (via root_print).
    params: settings dict; params['log_interval'] is the number of batches between logs.
    model: network to train (SerialNet or ParallelNet in this notebook).
    train_loader: DataLoader yielding (data, target) batches.
    optimizer: optimizer over model.parameters().
    epoch: epoch number, used only in log messages.
    compose: called as compose(criterion, output, target) to evaluate the loss
      (model.compose for ParallelNet, plain application for SerialNet).
    device: device each batch is moved to.

  "Time Per Batch" covers the device transfer and the forward and backward passes.
  The optimizer step is outside the timed region.
  """
  model.train()
  criterion = nn.CrossEntropyLoss()
  num_samples = len(train_loader.dataset)
  num_batches = len(train_loader)
  total_time = 0.0
  samples_seen = 0  # exact count of samples processed so far (the last batch may be short)
  batches_done = 0
  loss = None
  for batch_idx, (data, target) in enumerate(train_loader):
    start_time = timer()
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = compose(criterion, output, target)
    loss.backward()
    stop_time = timer()
    optimizer.step()

    total_time += stop_time - start_time
    batches_done = batch_idx + 1
    if batch_idx % params['log_interval'] == 0:
      # The sample count shown is the number processed *before* this batch.
      root_print(rank, 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime Per Batch {:.6f}'.format(
        epoch, samples_seen, num_samples,
        100. * batch_idx / num_batches, loss.item(), total_time / batches_done))
    samples_seen += len(data)

  if loss is None:
    # Empty loader: there is no batch or loss to summarise.
    root_print(rank, 'Train Epoch: {} -- train_loader produced no batches'.format(epoch))
    return

  root_print(rank, 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime Per Batch {:.6f}'.format(
    epoch, samples_seen, num_samples,
    100. * batches_done / num_batches, loss.item(), total_time / batches_done))

# Evaluate model on validation data
def test(rank, model, test_loader, compose, device):
  """Evaluate `model` on `test_loader` and print average loss and accuracy on rank 0.

  Args:
    rank: MPI rank of this process; accuracy is only counted, and results only
      printed, on rank 0.
    model: network to evaluate (SerialNet or ParallelNet in this notebook).
    test_loader: DataLoader yielding (data, target) batches.
    compose: called as compose(criterion, output, target) to evaluate the loss.
    device: device each batch is moved to.

  The loss is summed per sample (reduction='sum') and then divided by the dataset
  size, so the printed value is a per-sample average.  Summing per-batch means
  instead would under-report it by roughly a factor of the batch size.
  """
  model.eval()
  test_loss = 0.0
  correct = 0
  criterion = nn.CrossEntropyLoss(reduction='sum')
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += compose(criterion, output, target).item()

      if rank == 0:
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

  num_samples = len(test_loader.dataset)
  if num_samples == 0:
    root_print(rank, '\nTest set: empty, nothing to evaluate\n')
    return

  test_loss /= num_samples

  root_print(rank, '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples,
    100. * correct / num_samples))

In [13]:
%%px

# Parallel printing helper function
def root_print(rank, s):
  """Print `s` only on rank 0, so a message is not repeated once per MPI rank."""
  if rank != 0:
    return
  print(s)

# Compute number of parallel in time levels
def compute_levels(num_steps, min_coarse_size, cfactor):
  """Return the number of layer-parallel (multigrid-in-time) levels to use.

  This is the largest L >= 1 with min_coarse_size * cfactor**(L-1) <= num_steps, so
  every coarsening by `cfactor` still leaves at least `min_coarse_size` steps.  The
  result is clamped to 1 when num_steps < min_coarse_size * cfactor.

  It uses exact repeated multiplication instead of floor(log(...)).  The float log can
  land just below an integer (math.log(1000, 10) == 2.999...), which dropped a level
  whenever num_steps / min_coarse_size was an exact power of cfactor.

  Raises:
    ValueError: if cfactor <= 1 or min_coarse_size <= 0 (no finite level count).
  """
  if cfactor <= 1:
    raise ValueError('cfactor must be > 1, got {}'.format(cfactor))
  if min_coarse_size <= 0:
    raise ValueError('min_coarse_size must be > 0, got {}'.format(min_coarse_size))

  levels = 1
  # Steps required on the finest grid before one more level can be added.
  needed = min_coarse_size * cfactor
  while needed <= num_steps:
    levels += 1
    needed *= cfactor
  return levels

In [49]:
%%px
# Default settings for the network, the training run, and the layer-parallel solver
params = {
  # general
  'seed': 1,                      # random seed
  'log_interval': 10,             # how many batches to wait before logging training status
  'dataset': 'digits',            # 'digits' or 'fashion' MNIST
  'serial_file': None,            # serial-network file: loaded in serial mode, written by saveSerialNet in layer-parallel mode
  # network
  'steps': 16,                    # number of time steps (layers) in the resnet layer-parallel part
  'channels': 4,                  # number of channels in resnet layer
  'tf': 1.0,                      # final time for resnet layer-parallel part
  # training
  'percent_data': 1.0,            # how much of the data to read in and use for training/testing
  'batch_size': 50,               # input batch size for training
  'epochs': 2,                    # number of epochs to train
  'lr': 0.01,                     # learning rate
  # layer-parallel solver
  'force_lp': False,              # use layer parallel even if there is only 1 MPI rank
  'lp_levels': 3,                 # max number of layer parallel levels; -1 = compute from steps and lp_cfactor
  'lp_iters': 2,                  # layer parallel iterations
  'lp_fwd_iters': -1,             # layer parallel (forward) iterations, if -1 use lp_iters
  'lp_print': 0,                  # layer parallel internal print level: 0, 1, 2, 3
  'lp_braid_print': 0,            # layer parallel braid print level: 0, 1, 2, 3
  'lp_cfactor': 4,                # layer parallel coarsening factor
  'lp_finefcf': False,            # layer parallel fine FCF on or off
  'no_cuda': False,               # disables CUDA training
  'warm_up': False,               # warm up for GPU timings
  'lp_gpu_direct_commu': False,   # layer parallel GPU direct communication
  'lp_user_mpi_buf': False,       # layer parallel use user-defined mpi buffers
  'lp_use_downcycle': False,      # layer parallel use downcycle on or off
}

In [50]:
%%px
# Set up MPI, the device, serial vs. layer-parallel mode, the level count, and the
# number of steps (layers) owned by each rank.

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
procs = comm.Get_size()

use_cuda = not params['no_cuda'] and torch.cuda.is_available()

device, host = torchbraid.utils.getDevice(comm=comm)
if not use_cuda:
  # CUDA disabled or unavailable: use the CPU regardless of what getDevice chose.
  device = torch.device("cpu")
print(f'Run info rank: {rank}: Torch version: {torch.__version__} | Device: {device} | Host: {host}')

# Default to the serial network on one processor; params['force_lp'] lets the user
# run layer-parallel anyway.
force_lp = bool(params['force_lp'] or procs > 1)

torch.manual_seed(params['seed'])

if params['lp_levels'] == -1:
  min_coarse_size = 3
  params['lp_levels'] = compute_levels(params['steps'], min_coarse_size, params['lp_cfactor'])

# Every rank must own the same number of layers.  Raise (on all ranks) rather than
# calling sys.exit, which does not belong in a notebook engine and reported success.
if params['steps'] % procs != 0:
  raise ValueError('Steps must be evenly divisible by the number of processors: %d %d'
                   % (params['steps'], procs))
local_steps = params['steps'] // procs

root_print(rank, 'MNIST ODENet:')

[stdout:0] No GPUs to be used, CPU only
Run info rank: 0: Torch version: 1.12.1 | Device: cpu | Host: cpu
MNIST ODENet:


[stdout:1] Run info rank: 1: Torch version: 1.12.1 | Device: cpu | Host: cpu


In [51]:
%%px
# Read in Digits MNIST or Fashion MNIST, selected by params['dataset'].
if params['dataset'] == 'digits':
  root_print(rank, '-- Using Digit MNIST')
  transform = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))
                                  ])
  dataset = datasets.MNIST('./digit-data', download=False, transform=transform)
elif params['dataset'] == 'fashion':
  root_print(rank, '-- Using Fashion MNIST')
  # NOTE(review): no Normalize here, unlike the digits branch -- confirm this is intended.
  transform = transforms.Compose([transforms.ToTensor()])
  dataset = datasets.FashionMNIST('./fashion-data', download=False, transform=transform)
else:
  raise ValueError("params['dataset'] must be 'digits' or 'fashion', got %r" % (params['dataset'],))

root_print(rank, '-- procs    = {}\n'
                 '-- channels = {}\n'
                 '-- tf       = {}\n'
                 '-- steps    = {}'.format(procs, params['channels'], params['tf'], params['steps']))

# Both subsets are carved out of the same (training) split: the first 50000 images,
# scaled by percent_data, for training and the following block of up to 10000 for
# validation.
train_size = int(50000 * params['percent_data'])
test_size = int(10000 * params['percent_data'])
train_set = torch.utils.data.Subset(dataset, range(train_size))
test_set = torch.utils.data.Subset(dataset, range(train_size, train_size + test_size))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=params['batch_size'], shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=params['batch_size'], shuffle=False)

root_print(rank, '')

[stdout:0] -- Using Digit MNIST
-- procs    = 2
-- channels = 4
-- tf       = 1.0
-- steps    = 16



In [52]:
%%px
# Build the layer-parallel network (and matching compose) or fall back to the serial one.
if force_lp:
  root_print(rank, 'Using ParallelNet:')
  root_print(rank, '-- max_levels     = {}\n'
                   '-- max_iters      = {}\n'
                   '-- fwd_iters      = {}\n'
                   '-- cfactor        = {}\n'
                   '-- fine fcf       = {}\n'
                   '-- skip down      = {}\n'.format(params['lp_levels'],
                                                     params['lp_iters'],
                                                     params['lp_fwd_iters'],
                                                     params['lp_cfactor'],
                                                     params['lp_finefcf'],
                                                     not params['lp_use_downcycle'] ))
  model = ParallelNet(channels=params['channels'],
                      local_steps=local_steps,
                      max_levels=params['lp_levels'],
                      max_iters=params['lp_iters'],
                      fwd_max_iters=params['lp_fwd_iters'],
                      print_level=params['lp_print'],
                      braid_print_level=params['lp_braid_print'],
                      cfactor=params['lp_cfactor'],
                      fine_fcf=params['lp_finefcf'],
                      skip_downcycle=not params['lp_use_downcycle'],
                      fmg=False,
                      Tf=params['tf'],
                      relax_only_cg=False,
                      user_mpi_buf=params['lp_user_mpi_buf'],
                      gpu_direct_commu=params['lp_gpu_direct_commu']).to(device)

  if params['serial_file'] is not None:
    # In layer-parallel mode serial_file is an output: the freshly built network is
    # written out as a SerialNet.
    model.saveSerialNet(params['serial_file'])
  compose = model.compose

  # Braid timer files for the forward and backward solves, tagged with the run configuration.
  timer_tag = 's_%d_c_%d_bs_%d_p_%d_gpuc_%d' % (params['steps'], params['channels'], params['batch_size'],
                                                procs, params['lp_gpu_direct_commu'])
  model.parallel_nn.fwd_app.setTimerFile('b_fwd_' + timer_tag)
  model.parallel_nn.bwd_app.setTimerFile('b_bwd_' + timer_tag)
else:
  root_print(rank, 'Using SerialNet:')
  root_print(rank, '-- serial file = {}\n'.format(params['serial_file']))
  if params['serial_file'] is not None:
    print('loading model')
    # torch.load unpickles a whole model object: only load files you trust.
    # map_location lets a file saved from a GPU run be loaded onto this device.
    model = torch.load(params['serial_file'], map_location=device)
  else:
    model = SerialNet(channels=params['channels'], local_steps=local_steps, Tf=params['tf']).to(device)
  compose = lambda op, *p: op(*p)

[stdout:0] Using ParallelNet:
-- max_levels     = 3
-- max_iters      = 2
-- fwd_iters      = -1
-- cfactor        = 4
-- fine fcf       = False
-- skip down      = True



In [53]:
%%px
# Optimizer, timing accumulators, and an optional warm-up epoch.
optimizer = optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)

epoch_times = []
test_times = []

# check out the initial conditions (diagnose() is defined in the last cell)
# if force_lp:
#   diagnose(rank, model, test_loader,0)

if params['warm_up']:
  # One untimed epoch (numbered 0) so that later timings exclude start-up costs.
  # Note that this epoch also updates the model parameters.
  warm_up_timer = timer()
  train(rank=rank, params=params, model=model, train_loader=train_loader, optimizer=optimizer, epoch=0,
        compose=compose, device=device)
  if force_lp:
    model.parallel_nn.timer_manager.resetTimers()
    model.parallel_nn.fwd_app.resetBraidTimer()
    model.parallel_nn.bwd_app.resetBraidTimer()
  if use_cuda:
    torch.cuda.synchronize()
  epoch_times = []
  test_times = []
  root_print(rank, f'Warm up timer {timer() - warm_up_timer}')

In [54]:
%%px
# Train for params['epochs'] epochs, timing each training epoch and each evaluation.
for epoch in range(1, params['epochs'] + 1):
  start_time = timer()
  train(rank=rank, params=params, model=model, train_loader=train_loader, optimizer=optimizer, epoch=epoch,
        compose=compose, device=device)
  end_time = timer()
  epoch_times += [end_time - start_time]

  start_time = timer()

  test(rank=rank, model=model, test_loader=test_loader, compose=compose, device=device)
  end_time = timer()
  test_times += [end_time - start_time]

  # print out some diagnostics
  # if force_lp:
  #  diagnose(rank, model, test_loader,epoch)

if force_lp:
  timer_str = model.parallel_nn.getTimersString()
  root_print(rank, timer_str)


def _time_summary(times):
  """Format the mean of `times` (and, given 2+ samples, their std dev) to 2 decimals."""
  if not times:
    return 'n/a'
  summary = '{:.2f}'.format(stats.mean(times))
  if len(times) > 1:
    # statistics.stdev needs at least two data points
    summary += ' (1 std dev {:.2f})'.format(stats.stdev(times))
  return summary


root_print(rank, f'TIME PER EPOCH: {_time_summary(epoch_times)}')
root_print(rank, f'TIME PER TEST:  {_time_summary(test_times)}')



%px:   0%|          | 0/2 [00:00<?, ?tasks/s]

Received Keyboard Interrupt. Sending signal SIGINT to engines...


In [None]:
%%px

# TODO: move all diagnose-related code into this cell.

def diagnose(rank, model, test_loader, epoch, device=None):
  """Run one validation batch with torchbraid diagnostics enabled and plot the results.

  Enables model.parallel_nn diagnostics and evaluates the first batch of
  `test_loader`.  It then reads model.getDiagnostics().  On rank 0 it plots the
  feature norms per time step ('step_in'[0] followed by 'step_out') and the
  parameter norms ('params') per step.  The figure is saved to
  'diagnose{epoch:03d}.png'.  Other ranks return after the evaluation.

  Args:
    rank: MPI rank; only rank 0 plots.
    model: network exposing .parallel_nn.diagnostics() and .getDiagnostics() (ParallelNet).
    test_loader: iterable of (data, target) batches; only the first batch is used.
    epoch: epoch number, used in the title and the file name.
    device: optional device to move the batch to; by default it is left where it is.

  Diagnostics are switched off again afterwards (even on error) so that later
  training does not keep collecting them.
  """
  model.parallel_nn.diagnostics(True)
  try:
    model.eval()
    data, _ = next(iter(test_loader))
    if device is not None:
      data = data.to(device)

    # run the model so the diagnostic information is collected
    with torch.no_grad():
      model(data)

    diagnostic = model.getDiagnostics()
  finally:
    # NOTE(review): assumes diagnostics were off before this call -- confirm.
    model.parallel_nn.diagnostics(False)

  if rank != 0:
    return

  features = np.array([diagnostic['step_in'][0]] + diagnostic['step_out'])
  # named param_norms (not `params`) to avoid confusion with the global settings dict
  param_norms = np.array(diagnostic['params'])

  fig, axs = pyplot.subplots(2, 1)
  axs[0].plot(range(len(features)), features)
  axs[0].set_ylabel('Feature Norm')

  # parameter norms sit between the feature samples, at the middle of each time step
  coords = [0.5 + i for i in range(len(features) - 1)]
  axs[1].set_xlim([0, len(features) - 1])
  axs[1].plot(coords, param_norms, '*')
  # NOTE(review): assumes 'params' is 2-D (steps x norms per step) -- shape[1] fails otherwise.
  axs[1].set_ylabel('Parameter Norms: {}/tstep'.format(param_norms.shape[1]))
  axs[1].set_xlabel('Time Step')

  fig.suptitle('Values in Epoch {}'.format(epoch))

  # pyplot.show()
  fig.savefig('diagnose{:03d}.png'.format(epoch))
  # diagnose may run every epoch: release the figure so figures do not accumulate
  pyplot.close(fig)