support data_parallel training and ucf101 dataset #4819

Merged: 9 commits, Sep 1, 2020
61 changes: 46 additions & 15 deletions dygraph/tsm/model.py
@@ -16,6 +16,7 @@
import time
import sys
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
import math
@@ -28,7 +29,8 @@ def __init__(self,
filter_size,
stride=1,
groups=1,
act=None):
act=None,
name=None):
super(ConvBNLayer, self).__init__()

self._conv = Conv2D(
@@ -39,14 +41,22 @@ def __init__(self,
padding=(filter_size - 1) // 2,
groups=None,
act=None,
param_attr=fluid.param_attr.ParamAttr(),
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]

self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=fluid.param_attr.ParamAttr(),
bias_attr=fluid.param_attr.ParamAttr())
param_attr=ParamAttr(
name=bn_name + "_scale"), #fluid.param_attr.ParamAttr(),
bias_attr=ParamAttr(bn_name +
"_offset"), #fluid.param_attr.ParamAttr())
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")

def forward(self, inputs):
y = self._conv(inputs)
@@ -61,32 +71,36 @@ def __init__(self,
num_filters,
stride,
shortcut=True,
seg_num=8):
seg_num=8,
name=None):
super(BottleneckBlock, self).__init__()

self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
act=None,
name=name + "_branch2c")

if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
self.seg_num = seg_num
self._num_channels_out = int(num_filters * 4)
@@ -119,7 +133,12 @@ def __init__(self, name_scope, config):
num_filters = [64, 128, 256, 512]

self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name="conv1")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')

@@ -129,14 +148,23 @@ def __init__(self, name_scope, config):
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if self.layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)

bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
conv_name,
BottleneckBlock(
num_channels=num_channels,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
seg_num=self.seg_num))
seg_num=self.seg_num,
name=conv_name))
num_channels = int(bottleneck_block._num_channels_out)
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True
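
A quick worked example of the naming scheme above (a minimal sketch mirroring the branch logic, not part of the diff): ResNet-50 names its blocks with letters ("res2a", "res2b", "res2c", ...), while ResNet-101/152 switch to "a", "b1", "b2", ... in the third stage, presumably to match the naming used in the released ResNet pretrained weights.

    # minimal sketch of the conv_name logic in the loop above
    def conv_name(layers, block, i):
        if layers in [101, 152] and block == 2:
            return "res" + str(block + 2) + ("a" if i == 0 else "b" + str(i))
        return "res" + str(block + 2) + chr(97 + i)

    assert conv_name(50, 0, 2) == "res2c"
    assert conv_name(101, 2, 3) == "res4b3"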
@@ -151,9 +179,12 @@ def __init__(self, name_scope, config):
self.class_dim,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)),
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
bias_attr=fluid.param_attr.ParamAttr(
learning_rate=2.0, regularizer=fluid.regularizer.L2Decay(0.)))
learning_rate=2.0,
regularizer=fluid.regularizer.L2Decay(0.),
name="fc_0.b_0"))

def forward(self, inputs):
y = fluid.layers.reshape(
1 change: 1 addition & 0 deletions dygraph/tsm/run_ucf101.sh
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 python3.7 -m paddle.distributed.launch --started_port 38989 --log_dir ./mylog.ucf101.frames tsm.py --config=./tsm_ucf101.yaml --use_gpu=True --use_data_parallel=True
148 changes: 128 additions & 20 deletions dygraph/tsm/train.py
@@ -24,6 +24,7 @@
from model import TSM_ResNet
from config_utils import *
from reader import KineticsReader
from ucf101_reader import UCF101Reader

logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
@@ -65,12 +66,39 @@ def parse_args():
type=int,
default=None,
help='epoch number, 0 for read from config file')
parser.add_argument(
'--use_data_parallel',
type=ast.literal_eval,
default=True,
help='default use data parallel.')
parser.add_argument(
'--model_save_dir',
type=str,
default='./output',
help='default model save in ./output.')
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help='path to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--model_path_pre',
type=str,
default='tsm',
help='default model path pre is tsm.')
Reviewer comment (Contributor): What is the meaning of model_path_pre?
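
As used later in train() below, model_path_pre is the filename prefix for saved checkpoints; a minimal sketch of the path it produces, taken from the save logic further down:

    # with the defaults, checkpoints land at ./output/tsm_epoch0, ./output/tsm_epoch1, ...
    model_path = os.path.join(args.model_save_dir,
                              args.model_path_pre + "_epoch{}".format(epoch))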

parser.add_argument(
'--resnet50_dir',
type=str,
default='./ResNet50_pretrained/',
help='default resnet50 dir is ./ResNet50_pretrained/.')

args = parser.parse_args()
return args


def val(epoch, model, cfg, args):
reader = KineticsReader(mode="valid", cfg=cfg)
reader = UCF101Reader(name="TSM", mode="valid", cfg=cfg)
reader = reader.create_reader()
total_loss = 0.0
total_acc1 = 0.0
@@ -101,9 +129,9 @@ def val(epoch, model, cfg, args):
epoch, batch_id,
avg_loss.numpy()[0], acc_top1.numpy()[0], acc_top5.numpy()[0]))

print('Finish loss {} , acc1 {} , acc5 {}'.format(
total_loss / total_sample, total_acc1 / total_sample, total_acc5 /
total_sample))
print('TEST Epoch {}, iter {}, Finish loss {} , acc1 {} , acc5 {}'.format(
epoch, batch_id, total_loss / total_sample, total_acc1 / total_sample,
total_acc5 / total_sample))


def create_optimizer(cfg, params):
@@ -132,26 +160,62 @@ def train(args):
valid_config = merge_configs(config, 'valid', vars(args))
print_configs(train_config, 'Train')

use_data_parallel = False
use_data_parallel = args.use_data_parallel
trainer_count = fluid.dygraph.parallel.Env().nranks
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
#(data_parallel step1/6)
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

#load resnet50 pretrain
pre_state_dict = fluid.load_program_state(args.resnet50_dir)
for key in pre_state_dict.keys():
print('pre_state_dict.key: {}'.format(key))
Reviewer comment (Contributor): Is this print statement debugging code? Printing every parameter name is too verbose; consider commenting it out or removing it.


with fluid.dygraph.guard(place):
if use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()

#1. init model
video_model = TSM_ResNet("TSM", train_config)

#2. set resnet50 backbone weights
param_state_dict = {}
model_dict = video_model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys(
) and weight_name != "fc_0.w_0" and weight_name != "fc_0.b_0":
print('succ Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
print('fail Load weight: {}'.format(weight_name))
param_state_dict[key] = model_dict[key]
video_model.set_dict(param_state_dict)

#3. init optim
optimizer = create_optimizer(train_config.TRAIN,
video_model.parameters())
if use_data_parallel:
#(data_parallel step2,3/6)
strategy = fluid.dygraph.parallel.prepare_context()
video_model = fluid.dygraph.parallel.DataParallel(video_model,
strategy)

# 4. load checkpoint
if args.checkpoint:
Reviewer comment (Contributor): When resuming from a checkpoint, should the epoch counter be adjusted accordingly?
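
One possible answer, sketched here purely as an assumption (this PR does not add such a flag): carry the starting epoch alongside the checkpoint so the loop resumes where it left off.

    # hypothetical sketch only; --resume_epoch is not an argument in this PR
    start_epoch = args.resume_epoch if args.checkpoint else 0
    for epoch in range(start_epoch, train_config.TRAIN.epoch):
        ...  # training loop unchanged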

assert os.path.exists(args.checkpoint + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(args.checkpoint)
assert os.path.exists(args.checkpoint + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(args.checkpoint)
para_dict, opti_dict = fluid.dygraph.load_dygraph(args.checkpoint)
video_model.set_dict(para_dict)
optimizer.set_dict(opti_dict)

# 5. reader
bs_denominator = 1
if args.use_gpu:
# check number of GPUs
gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
if gpus == "":
pass
@@ -168,27 +232,36 @@ def train(args):
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
bs_denominator)

train_reader = KineticsReader(mode="train", cfg=train_config)
train_reader = UCF101Reader(name="TSM", mode="train", cfg=train_config)

train_reader = train_reader.create_reader()
if use_data_parallel:
#(data_parallel step4/6)
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)

# 6. train loop
for epoch in range(train_config.TRAIN.epoch):
video_model.train()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
t_last = time.time()
# 6.1 for each batch, call model() , backward(), and minimize()
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
Reviewer comment (Contributor): Could t1-t5 be given clearer names, e.g. batch_start_time?

x_data = np.array([item[0] for item in data])
y_data = np.array([item[1] for item in data]).reshape([-1, 1])

imgs = to_variable(x_data)
labels = to_variable(y_data)
labels.stop_gradient = True

t2 = time.time()
outputs = video_model(imgs)
t3 = time.time()

loss = fluid.layers.cross_entropy(
input=outputs, label=labels, ignore_index=-1)
avg_loss = fluid.layers.mean(loss)
Reviewer comment (Member): Copy avg_loss to a new variable and print that copy instead of avg_loss, to avoid printing avg_loss after scale_loss(), which has already been divided by the number of cards.
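
A minimal sketch of the suggestion (not part of the diff): snapshot the loss value before scale_loss() divides it by the number of cards, and print the snapshot.

    # keep an unscaled copy purely for logging
    avg_loss_value = avg_loss.numpy()[0]

    if use_data_parallel:
        avg_loss = video_model.scale_loss(avg_loss)
        avg_loss.backward()
        video_model.apply_collective_grads()
    else:
        avg_loss.backward()

    print('loss: %.5f' % avg_loss_value)  # the local, unscaled loss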

@@ -198,34 +271,69 @@
acc_top5 = fluid.layers.accuracy(
input=outputs, label=labels, k=5)

current_step_lr = optimizer.current_step_lr()
if use_data_parallel:
#(data_parallel step5/6)
avg_loss = video_model.scale_loss(avg_loss)
avg_loss.backward()
video_model.apply_collective_grads()
else:
avg_loss.backward()

t4 = time.time()
optimizer.minimize(avg_loss)
video_model.clear_gradients()
t5 = time.time()

total_loss += avg_loss.numpy()[0]
total_acc1 += acc_top1.numpy()[0]
total_acc5 += acc_top5.numpy()[0]
total_sample += 1

print('TRAIN Epoch {}, iter {}, loss = {}, acc1 {}, acc5 {}'.
format(epoch, batch_id,
avg_loss.numpy()[0],
acc_top1.numpy()[0], acc_top5.numpy()[0]))
print(
'TRAIN Epoch: %d, iter: %d, loss: %.5f, acc1: %.5f, acc5: %.5f, lr: %.5f, forward_cost:%.5f s, backward_cost:%.5f s, minimize_cost:%.5f s, to_variable_cost: %.5f s, batch_cost: %.5f s, reader_cost: %.5f s'
% (epoch, batch_id, avg_loss.numpy()[0],
acc_top1.numpy()[0], acc_top5.numpy()[0],
current_step_lr, t3 - t2, t4 - t3, t5 - t4, t2 - t1,
t5 - t_last, t2 - t_last))
t_last = time.time()

print(
'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}'.
'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}, lr={}'.
format(epoch, total_loss / total_sample, total_acc1 /
total_sample, total_acc5 / total_sample))
total_sample, total_acc5 / total_sample,
current_step_lr))

# 6.2 save checkpoint
save_parameters = (not use_data_parallel) or (
use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0
) #(data_parallel step6/6)
if save_parameters:
if not os.path.isdir(args.model_save_dir):
os.makedirs(args.model_save_dir)
model_path = os.path.join(
args.model_save_dir,
args.model_path_pre + "_epoch{}".format(epoch))
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
print('save_dygraph End, Epoch {}/{} '.format(
epoch, train_config.TRAIN.epoch))

# 6.3 validation
video_model.eval()
val(epoch, video_model, valid_config, args)

if fluid.dygraph.parallel.Env().local_rank == 0:
fluid.dygraph.save_dygraph(video_model.state_dict(), "final")
# 7. save final model
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
model_path = os.path.join(args.model_save_dir,
args.model_path_pre + "_final")
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)

logger.info('[TRAIN] training finished')
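
Taken together, the six data_parallel steps marked in the diff follow the standard fluid dygraph recipe. A minimal sketch under that legacy API (MyModel, my_reader, compute_loss, and optimizer are placeholders, not names from this repo):

    import paddle.fluid as fluid

    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)           # step 1: one device per rank
    with fluid.dygraph.guard(place):
        model = MyModel()                                                   # placeholder model
        strategy = fluid.dygraph.parallel.prepare_context()                 # step 2: init the parallel context
        model = fluid.dygraph.parallel.DataParallel(model, strategy)        # step 3: wrap the model
        reader = fluid.contrib.reader.distributed_batch_reader(my_reader)   # step 4: shard the batch reader
        for data in reader():
            loss = compute_loss(model, data)                                # placeholder forward + loss
            loss = model.scale_loss(loss)                                   # step 5: scale the loss ...
            loss.backward()
            model.apply_collective_grads()                                  # ... and all-reduce the gradients
            optimizer.minimize(loss)
            model.clear_gradients()
        if fluid.dygraph.parallel.Env().local_rank == 0:                    # step 6: save on rank 0 only
            fluid.dygraph.save_dygraph(model.state_dict(), "final")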

