
Training with apex, before calling amp.init() #506

Open
molyswu opened this issue Sep 26, 2019 · 1 comment

molyswu commented Sep 26, 2019

I'm trying to run distributed training of a PyTorch model with apex and Horovod, and hit the following error:
[1,1]:
[1,1]:global [ 1/ 2], local [ 1/ 2]
[1,0]:global [ 0/ 2], local [ 0/ 2]
[1,1]:Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.
[1,1]:
[1,1]:Defaults for this optimization level are:
[1,1]:loss_scale : dynamic
[1,1]:opt_level : O1
[1,1]:master_weights : None
[1,1]:cast_model_type : None
[1,1]:patch_torch_functions : True
[1,1]:keep_batchnorm_fp32 : None
[1,1]:enabled : True
[1,1]:Processing user overrides (additional kwargs that are not None)...
[1,1]:After processing overrides, optimization options are:
[1,1]:loss_scale : 1.0
[1,1]:opt_level : O1
[1,1]:master_weights : None
[1,1]:cast_model_type : None
[1,1]:patch_torch_functions : True
[1,1]:keep_batchnorm_fp32 : None
[1,1]:enabled : True
[1,0]:Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.
[1,0]:
[1,0]:Defaults for this optimization level are:
[1,0]:cast_model_type : None
[1,0]:keep_batchnorm_fp32 : None
[1,0]:patch_torch_functions : True
[1,0]:enabled : True
[1,0]:loss_scale : dynamic
[1,0]:master_weights : None
[1,0]:opt_level : O1
[1,0]:Processing user overrides (additional kwargs that are not None)...
[1,0]:After processing overrides, optimization options are:
[1,0]:cast_model_type : None
[1,0]:keep_batchnorm_fp32 : None
[1,0]:patch_torch_functions : True
[1,0]:enabled : True
[1,0]:loss_scale : 1.0
[1,0]:master_weights : None
[1,0]:opt_level : O1
[1,0]:
[1,0]:
[1,0]:epoch: 0,batch: 0
[1,1]:Traceback (most recent call last):
[1,1]: File "train_1.py", line 186, in <module>
[1,1]: loss, outputs = model(imgs, targets)
[1,1]: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 547, in __call__
[1,1]: result = self.forward(*input, **kwargs)
[1,1]: File "/media/ai/sdc1/tools/horovod_study/yolov3-pytoch/PyTorch-YOLOv3/models.py", line 260, in forward
[1,1]: x, layer_loss = module[0](x, targets, img_dim)
[1,1]: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 547, in __call__
[1,1]: result = self.forward(*input, **kwargs)
[1,1]: File "/media/ai/sdc1/tools/horovod_study/yolov3-pytoch/PyTorch-YOLOv3/models.py", line 197, in forward
[1,1]: loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
[1,1]: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 547, in __call__
[1,1]: result = self.forward(*input, **kwargs)
[1,1]: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py", line 498, in forward
[1,1]: return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
[1,1]: File "/usr/local/lib/python3.5/dist-packages/apex/amp/wrap.py", line 124, in wrapper
[1,1]: raise NotImplementedError(custom_err_msg)
[1,1]:NotImplementedError:
[1,1]:amp does not work out-of-the-box with F.binary_cross_entropy or torch.nn.BCELoss. It requires that the output of the previous function be already a FloatTensor.
[1,1]:
[1,1]:Most models have a Sigmoid right before BCELoss. In that case, you can use
[1,1]: torch.nn.BCEWithLogitsLoss
[1,1]:to combine Sigmoid+BCELoss into a single layer that is compatible with amp.
[1,1]:Another option is to add
[1,1]: amp.register_float_function(torch, 'sigmoid')
[1,1]:before calling amp.init().
[1,1]:If you really know what you are doing, you can disable this warning by passing allow_banned=True to amp.init().
[1,0]: (rank 0 raises the identical NotImplementedError traceback)

Code (train_1.py):

from __future__ import division

from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from test import evaluate

from terminaltables import AsciiTable

import os
import sys
import math
import time
import datetime
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data.distributed
import horovod.torch as hvd
import tensorboardX
from tqdm import tqdm
from apex import amp

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")
parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')
parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
parser.add_argument('--batches-per-allreduce', type=int, default=1,
help='number of batches processed locally before '
'executing allreduce across workers; it multiplies '
'total batch size.')
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")
parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")
parser.add_argument('--seed', type=int, default=4,metavar='S',
help='random seed')
parser.add_argument('--base-lr',default=1e-3,help="learning rate")
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--momentum', type=float, default=0.9,
help='Adam momentum')
parser.add_argument('--wd', type=float, default=0.00005,
help='weight decay')
opt = parser.parse_args()
#args = parser.parse_args()
APEX = True
#Apex
if APEX:
    import apex
    from apex import amp

allreduce_batch_size = opt.batch_size * opt.batches_per_allreduce
print(opt)

logger = Logger("logs")
hvd.init()#init horovod
print("global [%2s/%2s], local [%2s/%2s]" % (hvd.rank(), hvd.size(), hvd.local_rank(), hvd.local_size()))
time.sleep(0.1 * (hvd.size() - hvd.rank()))

torch.cuda.set_device(hvd.local_rank())
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(4)
#use_cuda = not opt.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#if use_cuda:

os.makedirs("output", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# Get data configuration
data_config = parse_data_config(opt.data_config)
train_path = data_config["train"]
valid_path = data_config["valid"]
class_names = load_classes(data_config["names"])

# Initiate model
model = Darknet(opt.model_def).to(device)
model.apply(weights_init_normal)

# If specified we start from checkpoint
if opt.pretrained_weights:
    if opt.pretrained_weights.endswith(".pth"):
        model.load_state_dict(torch.load(opt.pretrained_weights))
    else:
        model.load_darknet_weights(opt.pretrained_weights)



# Get dataloader
dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)

# Horovod: use DistributedSampler to partition data among workers. Manually specify
# num_replicas=hvd.size() and rank=hvd.rank().

train_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, num_replicas=hvd.size(), rank=hvd.rank())
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    sampler=train_sampler,  # pass the sampler; shuffle must be off when a sampler is given
    num_workers=opt.n_cpu,
    pin_memory=True,
    collate_fn=dataset.collate_fn,
)

# Horovod: scale learning rate by the number of GPUs.

#optimizer = optim.SGD(model.parameters(), lr=(opt.base_lr *opt.batches_per_allreduce* hvd.size()) ,momentum=opt.momentum, weight_decay=opt.wd)
optimizer = torch.optim.Adam(model.parameters())

# Horovod: (optional) compression algorithm.

#compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none
FP16_ALLREDUCE = True

# Horovod: wrap optimizer with DistributedOptimizer.

optimizer = hvd.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters(),
    compression=hvd.Compression.fp16 if APEX and FP16_ALLREDUCE else hvd.Compression.none)

# Apex 
SYNC_BN =True
if APEX and SYNC_BN:
   model =  apex.parallel.convert_syncbn_model(model)
model.to(device)

# Horovod: broadcast parameters & optimizer state.

hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 # Apex
if APEX:
    OPT_LEVEL = "O1"
    KEEP_BATCHNORM_FP32 = None
    LOSS_SCALE = 1.0  # True is coerced to a loss scale of 1.0, matching "loss_scale : 1.0" in the log
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=OPT_LEVEL,
                                      keep_batchnorm_fp32=KEEP_BATCHNORM_FP32,
                                      loss_scale=LOSS_SCALE)


metrics = [
    "grid_size",
    "loss",
    "x",
    "y",
    "w",
    "h",
    "conf",
    "cls",
    "cls_acc",
    "recall50",
    "recall75",
    "precision",
    "conf_obj",
    "conf_noobj",
]

for epoch in range(opt.epochs):
    model.train()
    start_time = time.time()
    for batch_i, (_, imgs, targets) in enumerate(dataloader):
        
        if hvd.rank() == 0:
           print("\n\nepoch: %s,batch: %s" %(epoch,batch_i))
        batches_done = len(dataloader) * epoch + batch_i
        torch.manual_seed(batches_done)
        torch.cuda.manual_seed_all(batches_done)
        imgs = Variable(imgs.to(device))
        targets = Variable(targets.to(device), requires_grad=False)
        loss, outputs = model(imgs, targets)
        optimizer.zero_grad()  # must run on every rank, not only rank 0
        #loss.backward()

        #if batches_done % opt.gradient_accumulations:
            # Accumulates gradient before each step
           # optimizer.step()
            #optimizer.zero_grad()
        #Apex
        # Apex + Horovod: synchronize gradients inside the scale_loss context,
        # then step without a second synchronize
        if APEX:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
                optimizer.synchronize()
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            loss.backward()
            optimizer.step()
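As a side note on the commented-out SGD line above: Horovod's convention is to scale the base learning rate by the number of samples that now contribute to each update. A quick arithmetic sketch (the two-worker count and `batches_per_allreduce=2` below are made up for illustration; the script's defaults are `batch_size=8` and `batches_per_allreduce=1`):

```python
# Illustrative numbers only, not the script's defaults.
batch_size = 8
batches_per_allreduce = 2
world_size = 2            # stands in for hvd.size()
base_lr = 1e-3

# per-worker samples processed between allreduces
allreduce_batch_size = batch_size * batches_per_allreduce
# samples contributing to a single synchronized update across all workers
global_batch = allreduce_batch_size * world_size
# the scaling used by the commented-out SGD line
scaled_lr = base_lr * batches_per_allreduce * world_size

print(allreduce_batch_size, global_batch, scaled_lr)  # 16 32 0.004
```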

Thank you!


ptrblck commented Sep 26, 2019

Hi @molyswu,
the error points to:

amp does not work out-of-the-box with F.binary_cross_entropy or torch.nn.BCELoss. It requires that the output of the previous function be already a FloatTensor.
[1,1]:
[1,1]:Most models have a Sigmoid right before BCELoss. In that case, you can use
[1,1]: torch.nn.BCEWithLogitsLoss
[1,1]:to combine Sigmoid+BCELoss into a single layer that is compatible with amp.
[1,1]:Another option is to add
[1,1]: amp.register_float_function(torch, 'sigmoid')
[1,1]:before calling amp.init().
[1,1]:If you really know what you are doing, you can disable this warning by passing allow_banned=True to amp.init().

Could you try the suggested fixes?
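For context, the reason amp bans Sigmoid followed by BCELoss is numerical: in reduced precision the sigmoid saturates to exactly 1.0, so `log(1 - p)` blows up, whereas `BCEWithLogitsLoss` evaluates the fused expression `max(x, 0) - x*y + log1p(exp(-|x|))` and never forms `1 - p`. A plain-Python sketch of this (no torch required; the function names here are illustrative, not apex/PyTorch API):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_naive(x, y):
    # Sigmoid followed by BCELoss, i.e. the code path the traceback shows
    p = sigmoid(x)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(x, y):
    # The fused, numerically stable form that nn.BCEWithLogitsLoss evaluates
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

# For moderate logits the two agree...
for x in (-3.0, -0.5, 0.0, 1.7):
    for y in (0.0, 1.0):
        assert abs(bce_naive(x, y) - bce_with_logits(x, y)) < 1e-9

# ...but for a large logit, sigmoid(40.0) rounds to exactly 1.0, so the
# naive form evaluates log(0) while the fused form stays finite.
try:
    bce_naive(40.0, 0.0)
    naive_blew_up = False
except ValueError:
    naive_blew_up = True
assert naive_blew_up
assert math.isfinite(bce_with_logits(40.0, 0.0))
```

With that in mind, the first suggested fix (swapping `nn.BCELoss` for `nn.BCEWithLogitsLoss` in models.py and feeding it the pre-sigmoid logits) is usually the cleaner option; `amp.register_float_function(torch, 'sigmoid')` instead keeps the sigmoid output in FP32.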
