
Mixed precision training slower than FP32 training #297

Open
miguelvr opened this issue May 9, 2019 · 8 comments

Comments

@miguelvr

miguelvr commented May 9, 2019

I've been doing some experiments on CIFAR10 with ResNets and decided to give APEX AMP a try.

However, I ran into some performance issues:

  1. AMP with PyTorch's torch.nn.parallel.DistributedDataParallel was extremely slow.
  2. AMP with apex.parallel.DistributedDataParallel was slower than default training with torch.nn.parallel.DistributedDataParallel (no apex involved). For reference, normal training took about 15 minutes, while apex AMP training took 21 minutes (90 epochs on CIFAR-10 with ResNet20).

I followed the installation instructions, but I couldn't install the C++ extensions because of my GCC/CUDA version. Could this explain the slowdown?
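When apex is installed without its C++/CUDA extensions, Amp still runs but falls back to Python implementations for some of its work, such as unscaling gradients. A minimal sketch for checking which compiled modules are importable; the module names `apex_C` and `amp_C` are assumptions based on what apex's `setup.py` builds with `--cpp_ext` and `--cuda_ext`, and the helper itself is hypothetical:

```python
import importlib.util


def apex_extension_status(modules=("apex_C", "amp_C")):
    """Report which compiled apex extension modules can be found.

    Assumption: apex_C is built with --cpp_ext and amp_C with --cuda_ext.
    find_spec only locates the modules; it does not import them.
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in apex_extension_status().items():
        print(f"{name}: {'found' if found else 'missing'}")
```

If either module is reported missing, the installation is running without the fused kernels those extensions provide.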

You can see the code here:
https://github.com/braincreators/octconv/blob/34440209c4b37fb5198f75e4e8c052e92e80e85d/benchmarks/train.py#L1-L498

And run it (2 GPUs):

Without APEX AMP:
python -m torch.distributed.launch --nproc_per_node 2 train.py -c configs/cifar10/resnet20_small.yml --batch-size 128 --lr 0.1

With APEX AMP:
python -m torch.distributed.launch --nproc_per_node 2 train.py -c configs/cifar10/resnet20_small.yml --batch-size 128 --lr 0.1 --mixed-precision
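`torch.distributed.launch` starts one process per GPU and appends a `--local_rank` argument to each process's command line, so a script launched this way has to accept that flag. A minimal sketch of the argument handling, assuming hypothetical option names that mirror the flags in the commands above (this is not the parser from the linked train.py) and assuming `--batch-size` is the global batch split across processes, which matches the 128 total / 64 per device described in this thread:

```python
import argparse


def parse_args(argv=None):
    """Parse the flags used in the launch commands (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="CIFAR-10 training (sketch)")
    parser.add_argument("-c", "--config", required=True, help="path to a YAML config")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="global batch size, split across all processes")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--mixed-precision", action="store_true",
                        help="enable apex AMP")
    # Added to each process's command line by torch.distributed.launch.
    parser.add_argument("--local_rank", type=int, default=0)
    return parser.parse_args(argv)


def per_device_batch_size(global_batch_size, world_size):
    """Split the global batch evenly across processes."""
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by world size")
    return global_batch_size // world_size


if __name__ == "__main__":
    # Roughly what the second of two processes receives for the AMP command.
    args = parse_args(["--local_rank=1", "-c", "configs/cifar10/resnet20_small.yml",
                       "--batch-size", "128", "--lr", "0.1", "--mixed-precision"])
    print(args.local_rank, args.mixed_precision,
          per_device_batch_size(args.batch_size, 2))  # prints: 1 True 64
```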

@mcarilli
Contributor

mcarilli commented May 9, 2019

If your network or batch size is small, you may be underutilizing the device, in which case there's not much for Amp to accelerate. What kind of GPUs are you using? Also, is Amp slower than normal training within a single process as well?
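Because CUDA kernels run asynchronously, a fair O0-vs-O1 comparison in a single process should skip a few warm-up iterations and wait for the device before reading the clock. A minimal, framework-agnostic sketch of such a timer; the `step` and `sync` callables are hypothetical placeholders (with PyTorch, `sync` would typically be `torch.cuda.synchronize`):

```python
import time
from statistics import median


def time_steps(step, warmup=10, iters=50, sync=None):
    """Return the median wall-clock seconds per call of `step`.

    `step` runs one training iteration; `sync` (optional) blocks until
    queued device work has finished, e.g. torch.cuda.synchronize.
    """
    # Warm-up iterations absorb one-off costs (e.g. allocator growth,
    # cuDNN autotuning when enabled) and are not timed.
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()

    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for asynchronous kernels before stopping the clock
        timings.append(time.perf_counter() - start)
    return median(timings)


if __name__ == "__main__":
    # Stand-in workload; replace with a real forward/backward/step closure.
    print(f"{time_steps(lambda: sum(range(10_000))) * 1e3:.3f} ms/step")
```

Using the median rather than the mean keeps an occasional slow iteration (for example, a data-loader stall) from dominating the comparison.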

@miguelvr
Author

miguelvr commented May 9, 2019

I ran the tests with 2x GTX 1080 Ti and a batch size of 128 (so 64 per device).

I haven't tested with a single device yet. I'll let you know.

@zsef123

zsef123 commented May 10, 2019

The GTX 1080 Ti has low FP16 throughput.

If you want better performance with FP16, you need a Volta-architecture GPU or an RTX-series card.

Check this topic https://devtalk.nvidia.com/default/topic/1023708/gpu-accelerated-libraries/fp16-support-on-gtx-1060-and-1080/

@miguelvr
Author

Alright, I'll test it on a few V100s.

@mcarilli
Contributor

Yes, the 1080Ti was intended for gaming, so it has really low compute throughput for FP16 math. You need a Tensor Core-enabled GPU (Turing or Volta) to get best results with mixed precision.
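One quick way to tell whether a card belongs to the Volta/Turing class is its CUDA compute capability: the GTX 1080 Ti (Pascal) reports 6.1, the V100 reports 7.0, and Turing RTX cards report 7.5. A minimal sketch, assuming 7.0 as the cutoff; treat it only as a heuristic, because Turing GTX 16-series cards also report 7.5 but have no Tensor Cores. The helper name is hypothetical; `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple it expects.

```python
def likely_has_tensor_cores(capability):
    """Heuristic: Tensor Cores first appeared with compute capability 7.0 (Volta).

    `capability` is a (major, minor) tuple such as the one returned by
    torch.cuda.get_device_capability(). Caveat: Turing GTX 16-series cards
    report 7.5 but have no Tensor Cores, so this can give false positives.
    """
    major, minor = capability
    return (major, minor) >= (7, 0)


if __name__ == "__main__":
    for name, cap in [("GTX 1080 Ti", (6, 1)), ("V100", (7, 0)), ("RTX 2080", (7, 5))]:
        print(f"{name}: {likely_has_tensor_cores(cap)}")
```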

@patrickpjiang

Hello, I ran into the same problem when running experiments on a single RTX 2080: performance with O1 is worse than with O0, taking more time and consuming more memory.

The compute capability of the RTX 2080 is 7.5, so I think it should work with Amp (see the docs: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions). Does anyone know why?

@patrickpjiang

Here is my code; my environment is an RTX 2080 with CUDA 10.1:

```python
import argparse
from datetime import datetime

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from apex import amp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    args = parser.parse_args()
    args.gpu = 0
    train(args)


def train(args):
    gpu = args.gpu
    torch.manual_seed(0)
    model = torchvision.models.vgg19(pretrained=False)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = 200
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Wrap the model and optimizer for mixed precision ('O0' gives the FP32 baseline)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Data loading code
    train_dataset = torchvision.datasets.CIFAR100(
        root='./data',
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    start = datetime.now()
    total_step = len(train_loader)
    model.train()
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # with torch.autograd.profiler.profile(use_cuda=True) as prof:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            # Amp scales the loss so small FP16 gradients do not underflow
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            # print(prof)
            if (i + 1) % 100 == 0 and gpu == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, args.epochs, i + 1, total_step, loss.item()))
        # cumulative wall-clock time, printed after each epoch
        print("Training complete in: " + str(datetime.now() - start))


if __name__ == '__main__':
    main()
```

@patrickpjiang

I noticed an "ImportError", so I reinstalled apex (with PyTorch 1.4) and ran into another problem, a "version mismatch" error. Following #323, I deleted the version-matching check, and the installation finally completed with no warnings!

However, when I ran my test code, training time was still longer with O1 than with O0, although memory usage did decrease slightly. Is that normal?

| Mode | Memory | Time per epoch |
|------|--------|----------------|
| O0   | 3855M  | 26 s           |
| O1   | 3557M  | 33 s           |
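The table is easier to compare as relative changes. A minimal sketch that turns the O0 and O1 measurements into percentages (the helper name is illustrative only):

```python
def relative_change(baseline, candidate):
    """Percent change from baseline to candidate (positive = larger)."""
    return 100.0 * (candidate - baseline) / baseline


# Measurements from the table above (single RTX 2080, VGG19, CIFAR-100).
time_o0, time_o1 = 26, 33    # seconds per epoch
mem_o0, mem_o1 = 3855, 3557  # memory as reported ("M")

print(f"epoch time: {relative_change(time_o0, time_o1):+.1f}%")  # epoch time: +26.9%
print(f"memory:     {relative_change(mem_o0, mem_o1):+.1f}%")    # memory:     -7.7%
```

Applied to the first report in this thread (about 15 minutes for normal training vs 21 minutes with AMP), the same helper gives +40.0%.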

4 participants