## Distribute a Model Across Multiple GPUs with Pipeline Parallelism


This notebook demos pipeline parallelism added to PyTorch 1.8 using VGG-16 as an example. For more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#.

In [1]:
import os
import time
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast

import matplotlib.pyplot as plt
from PIL import Image

import json
from tqdm import tqdm

import sys

sys.path.insert(0, "../..") # to include ../helper_evaluate.py etc.
from helper_utils import set_all_seeds, set_deterministic

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

### parameters

In [2]:
random_seed = 47
learning_rate = 0.0001
batch_size = 8
epochs = 10

num_classes = 5

DEVICE = "cuda:0"

save_path = "vgg16_flower.pth"

### setting

In [3]:
set_all_seeds(random_seed)

set_deterministic()

### data

In [4]:
train_transform = transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.RandomResizedCrop((32, 32)),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

image_path = "D:/work/data/Python/flower_data/"
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_transform)
test_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=test_transform)

flower_list = train_dataset.class_to_idx
class_dict = dict((val, key) for key, val in flower_list.items())

# dump dict too json file
json_str = json.dumps(class_dict, indent=4)
with open("class_indices.json", "w") as json_file:
    json_file.write(json_str)
#     json.dump(class_dict, json_file, ensure_ascii=False)

num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 0])
print('Using {} dataloader workers every process'.format(num_workers))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

train_num = len(train_dataset)
test_num = len(test_dataset)

print("train images num: ", train_num)
print("test images num: ", test_num)

for images, labels in train_loader:
    print("images shape: ", images.size())
    print("labels shape: ", labels.size())
    break

Using 0 dataloader workers every process
train images num:  3306
test images num:  364
images shape:  torch.Size([8, 3, 32, 32])
labels shape:  torch.Size([8])


In [5]:
device = torch.device("cuda:0")

num_epochs = 2
for epoch in range(epochs):

    for batch_idx, (x, y) in enumerate(train_loader):
        
        print('Epoch:', epoch+1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size()[0])
        
        x = x.to(device)
        y = y.to(device)
        break

Epoch: 1 | Batch index: 0 | Batch size: 8
Epoch: 2 | Batch index: 0 | Batch size: 8
Epoch: 3 | Batch index: 0 | Batch size: 8
Epoch: 4 | Batch index: 0 | Batch size: 8
Epoch: 5 | Batch index: 0 | Batch size: 8
Epoch: 6 | Batch index: 0 | Batch size: 8
Epoch: 7 | Batch index: 0 | Batch size: 8
Epoch: 8 | Batch index: 0 | Batch size: 8
Epoch: 9 | Batch index: 0 | Batch size: 8
Epoch: 10 | Batch index: 0 | Batch size: 8


### model

In [6]:
class VGG16(torch.nn.Module):

    def __init__(self, num_classes):
        super().__init__()
        
        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2
        
        self.block_1 = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=3,
                          out_channels=64,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          # (1(32-1)- 32 + 3)/2 = 1
                          padding=1), 
                torch.nn.ReLU(),
                torch.nn.Conv2d(in_channels=64,
                          out_channels=64,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_2 = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=64,
                          out_channels=128,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(in_channels=128,
                          out_channels=128,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_3 = torch.nn.Sequential(        
                torch.nn.Conv2d(in_channels=128,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(in_channels=256,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),        
                torch.nn.Conv2d(in_channels=256,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
          
        self.block_4 = torch.nn.Sequential(   
                torch.nn.Conv2d(in_channels=256,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),        
                torch.nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),        
                torch.nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),            
                torch.nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_5 = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),            
                torch.nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),            
                torch.nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                torch.nn.ReLU(),    
                torch.nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))             
        )
            
        self.classifier = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(512, 4096),
            torch.nn.ReLU(True),
            #torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(True),
            #torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, num_classes),
        )
        
        
    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = self.block_5(x)
        x = self.classifier(x) # logits

        return x

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = VGG16(num_classes=num_classes)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

cuda


### 1) Regular (1-GPU) Training

In [8]:
best_acc = 0.0
# train_steps: train_num // batch_size
train_steps = len(train_loader)
for epoch in range(epochs):
    model.train()
    train_bar = tqdm(train_loader)
    running_loss = 0.0
    for step, (images, labels) in enumerate(train_bar):
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        logits = model(images)
        
        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        
        # update
        optimizer.step()
        
        # logging
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
    
    model.eval()
    correct = 0.0
    with torch.no_grad():
        test_bar = tqdm(test_loader)
        for images, labels in test_bar:
            images = images.to(device)
            labels = labels.to(device)
            
            logits = model(images)
            _, predict_labels = torch.max(logits, dim=1)
            correct += torch.eq(predict_labels, labels).sum().float()
    
    test_acc = correct / test_num
    train_loss = running_loss / train_steps
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, train_loss, test_acc))
    
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), save_path)
        print("save model pth to %s" % (save_path))
    
print("Finished training")

train epoch[1/10] loss:1.733: 100%|██████████████████████████████████████████████████| 414/414 [01:41<00:00,  4.07it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:06<00:00,  7.14it/s]


[epoch 1] train_loss: 1.604  val_accuracy: 0.245


train epoch[2/10] loss:1.808:   0%|▏                                                   | 1/414 [00:00<00:49,  8.27it/s]

save model pth to vgg16_flower.pth


train epoch[2/10] loss:1.090: 100%|██████████████████████████████████████████████████| 414/414 [00:48<00:00,  8.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 27.59it/s]


[epoch 2] train_loss: 1.504  val_accuracy: 0.374


train epoch[3/10] loss:1.493:   0%|▏                                                   | 1/414 [00:00<00:45,  9.11it/s]

save model pth to vgg16_flower.pth


train epoch[3/10] loss:1.736: 100%|██████████████████████████████████████████████████| 414/414 [00:49<00:00,  8.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 25.10it/s]


[epoch 3] train_loss: 1.311  val_accuracy: 0.415


train epoch[4/10] loss:1.110:   0%|▏                                                   | 1/414 [00:00<00:48,  8.50it/s]

save model pth to vgg16_flower.pth


train epoch[4/10] loss:0.821: 100%|██████████████████████████████████████████████████| 414/414 [00:51<00:00,  8.11it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 26.75it/s]


[epoch 4] train_loss: 1.259  val_accuracy: 0.519


  0%|                                                                                          | 0/414 [00:00<?, ?it/s]

save model pth to vgg16_flower.pth


train epoch[5/10] loss:1.611: 100%|██████████████████████████████████████████████████| 414/414 [00:50<00:00,  8.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 25.28it/s]


[epoch 5] train_loss: 1.243  val_accuracy: 0.533


train epoch[6/10] loss:1.010:   0%|▏                                                   | 1/414 [00:00<00:41,  9.87it/s]

save model pth to vgg16_flower.pth


train epoch[6/10] loss:0.736: 100%|██████████████████████████████████████████████████| 414/414 [00:51<00:00,  8.03it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 24.79it/s]
train epoch[7/10] loss:1.179:   0%|▏                                                   | 1/414 [00:00<00:47,  8.66it/s]

[epoch 6] train_loss: 1.206  val_accuracy: 0.497


train epoch[7/10] loss:0.994: 100%|██████████████████████████████████████████████████| 414/414 [00:51<00:00,  8.06it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 25.01it/s]


[epoch 7] train_loss: 1.184  val_accuracy: 0.541


train epoch[8/10] loss:0.883:   0%|▏                                                   | 1/414 [00:00<00:52,  7.94it/s]

save model pth to vgg16_flower.pth


train epoch[8/10] loss:0.935: 100%|██████████████████████████████████████████████████| 414/414 [00:51<00:00,  8.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 23.94it/s]


[epoch 8] train_loss: 1.148  val_accuracy: 0.547


train epoch[9/10] loss:1.033:   0%|▏                                                   | 1/414 [00:00<00:50,  8.20it/s]

save model pth to vgg16_flower.pth


train epoch[9/10] loss:1.333: 100%|██████████████████████████████████████████████████| 414/414 [00:53<00:00,  7.73it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 25.48it/s]
train epoch[10/10] loss:0.887:   0%|                                                   | 1/414 [00:00<00:49,  8.30it/s]

[epoch 9] train_loss: 1.106  val_accuracy: 0.486


train epoch[10/10] loss:0.379: 100%|█████████████████████████████████████████████████| 414/414 [00:52<00:00,  7.84it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 24.91it/s]


[epoch 10] train_loss: 1.118  val_accuracy: 0.607
save model pth to vgg16_flower.pth
Finished training


### 2) VGG16 with Pipeline Parallelism

In [9]:
block_1 = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=3,
                  out_channels=64,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  # (1(32-1)- 32 + 3)/2 = 1
                  padding=1), 
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=64,
                  out_channels=64,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=(2, 2),
                     stride=(2, 2))
)

block_2 = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=64,
                  out_channels=128,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=128,
                  out_channels=128,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=(2, 2),
                     stride=(2, 2))
)
        
block_3 = torch.nn.Sequential(        
        torch.nn.Conv2d(in_channels=128,
                  out_channels=256,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=256,
                  out_channels=256,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),        
        torch.nn.Conv2d(in_channels=256,
                  out_channels=256,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=(2, 2),
                     stride=(2, 2))
)
        
          
block_4 = torch.nn.Sequential(   
        torch.nn.Conv2d(in_channels=256,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),        
        torch.nn.Conv2d(in_channels=512,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),        
        torch.nn.Conv2d(in_channels=512,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),            
        torch.nn.MaxPool2d(kernel_size=(2, 2),
                     stride=(2, 2))
)
        
block_5 = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=512,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),            
        torch.nn.Conv2d(in_channels=512,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),            
        torch.nn.Conv2d(in_channels=512,
                  out_channels=512,
                  kernel_size=(3, 3),
                  stride=(1, 1),
                  padding=1),
        torch.nn.ReLU(),    
        torch.nn.MaxPool2d(kernel_size=(2, 2),
                     stride=(2, 2))             
)
            
classifier = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(512, 4096),
    torch.nn.ReLU(True),
    #torch.nn.Dropout(p=0.5),
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(True),
    #torch.nn.Dropout(p=0.5),
    torch.nn.Linear(4096, num_classes),
)

In [10]:
torch.distributed.is_available()

False

#### 设置环境变量

In [None]:
%env MASTER_ADDR=xxx.xx.xx.xx

In [None]:
%env MASTER_PORT=8891

Set up the RPC if it is not already running (more details at https://pytorch.org/docs/stable/rpc.html):

In [None]:
try:
    torch.distributed.rpc.init_rpc(name='node1', rank=0, world_size=1)
except RuntimeError as e:
    if str(e) == 'Address already in use':
        pass
    else:
        raise RuntimeError(e)

This is the main part for running the model on multiple GPUs.

1. We wrap the individual blocks into a Sequential model
2. The chunks refer to the microbatches, for more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#


In [None]:
from torch.distributed.pipeline.sync import Pipe


block1 = block_1.cuda(0)
block2 = block_2.cuda(0)
block3 = block_3.cuda(2)
block4 = block_4.cuda(2)
block4 = block_5.cuda(3)
block4 = classifier.cuda(0)

model_parallel = torch.nn.Sequential(
    block_1, block_2, block_3, block_4, block_5, classifier)
model_parallel = Pipe(model_parallel, chunks=8)
optimizer = torch.optim.Adam(model_parallel.parameters(), lr=learning_rate)

In [None]:
device = torch.device('cuda:0')
print(device)

best_acc = 0.0
# train_steps: train_num // batch_size
train_steps = len(train_loader)
for epoch in range(epochs):
    model.train()
    train_bar = tqdm(train_loader)
    running_loss = 0.0
    for step, (images, labels) in enumerate(train_bar):
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        logits = model_parallel(images)
        
        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        
        # update
        optimizer.step()
        
        # logging
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
    
    model.eval()
    correct = 0.0
    with torch.no_grad():
        test_bar = tqdm(test_loader)
        for images, labels in test_bar:
            images = images.to(device)
            labels = labels.to(device)
            
            logits = model(images)
            _, predict_labels = torch.max(logits, dim=1)
            correct += torch.eq(predict_labels, labels).sum().float()
    
    test_acc = correct / test_num
    train_loss = running_loss / train_steps
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, train_loss, test_acc))
    
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), save_path)
        print("save model pth to %s" % (save_path))
    
print("Finished training")