-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
support data_parallel training and ucf101 dataset #4819
Changes from 3 commits
0939bf8
2518c13
f6e637d
bf18cb7
0740cda
bdd14b6
3434f64
2982f07
3d335cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CUDA_VISIBLE_DEVICES=0,1,2,3 python3.7 -m paddle.distributed.launch --started_port 38989 --log_dir ./mylog.ucf101.frames tsm.py --config=./tsm_ucf101.yaml --use_gpu=True --use_data_parallel=True |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
from model import TSM_ResNet | ||
from config_utils import * | ||
from reader import KineticsReader | ||
from ucf101_reader import UCF101Reader | ||
|
||
logging.root.handlers = [] | ||
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' | ||
|
@@ -65,12 +66,39 @@ def parse_args(): | |
type=int, | ||
default=None, | ||
help='epoch number, 0 for read from config file') | ||
parser.add_argument( | ||
'--use_data_parallel', | ||
type=ast.literal_eval, | ||
default=True, | ||
help='default use data parallel.') | ||
parser.add_argument( | ||
'--model_save_dir', | ||
type=str, | ||
default='./output', | ||
help='default model save in ./output.') | ||
parser.add_argument( | ||
'--checkpoint', | ||
type=str, | ||
default=None, | ||
help='path to resume training based on previous checkpoints. ' | ||
'None for not resuming any checkpoints.') | ||
parser.add_argument( | ||
'--model_path_pre', | ||
type=str, | ||
default='tsm', | ||
help='default model path pre is tsm.') | ||
parser.add_argument( | ||
'--resnet50_dir', | ||
type=str, | ||
default='./ResNet50_pretrained/', | ||
help='default resnet50 dir is ./ResNet50_pretrained/.') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def val(epoch, model, cfg, args): | ||
reader = KineticsReader(mode="valid", cfg=cfg) | ||
reader = UCF101Reader(name="TSM", mode="valid", cfg=cfg) | ||
reader = reader.create_reader() | ||
total_loss = 0.0 | ||
total_acc1 = 0.0 | ||
|
@@ -101,9 +129,9 @@ def val(epoch, model, cfg, args): | |
epoch, batch_id, | ||
avg_loss.numpy()[0], acc_top1.numpy()[0], acc_top5.numpy()[0])) | ||
|
||
print('Finish loss {} , acc1 {} , acc5 {}'.format( | ||
total_loss / total_sample, total_acc1 / total_sample, total_acc5 / | ||
total_sample)) | ||
print('TEST Epoch {}, iter {}, Finish loss {} , acc1 {} , acc5 {}'.format( | ||
epoch, batch_id, total_loss / total_sample, total_acc1 / total_sample, | ||
total_acc5 / total_sample)) | ||
|
||
|
||
def create_optimizer(cfg, params): | ||
|
@@ -132,26 +160,62 @@ def train(args): | |
valid_config = merge_configs(config, 'valid', vars(args)) | ||
print_configs(train_config, 'Train') | ||
|
||
use_data_parallel = False | ||
use_data_parallel = args.use_data_parallel | ||
trainer_count = fluid.dygraph.parallel.Env().nranks | ||
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ | ||
if use_data_parallel else fluid.CUDAPlace(0) | ||
if not args.use_gpu: | ||
place = fluid.CPUPlace() | ||
elif not args.use_data_parallel: | ||
place = fluid.CUDAPlace(0) | ||
else: | ||
#(data_parallel step1/6) | ||
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) | ||
|
||
#load resnet50 pretrain | ||
pre_state_dict = fluid.load_program_state(args.resnet50_dir) | ||
for key in pre_state_dict.keys(): | ||
print('pre_state_dict.key: {}'.format(key)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. print是调试代码？所有参数名打印出来太长了，建议注释或删除 |
||
|
||
with fluid.dygraph.guard(place): | ||
if use_data_parallel: | ||
strategy = fluid.dygraph.parallel.prepare_context() | ||
|
||
#1. init model | ||
video_model = TSM_ResNet("TSM", train_config) | ||
|
||
#2. set resnet50 backbone weights | ||
param_state_dict = {} | ||
model_dict = video_model.state_dict() | ||
for key in model_dict.keys(): | ||
weight_name = model_dict[key].name | ||
if weight_name in pre_state_dict.keys( | ||
) and weight_name != "fc_0.w_0" and weight_name != "fc_0.b_0": | ||
print('succ Load weight: {}, shape: {}'.format( | ||
weight_name, pre_state_dict[weight_name].shape)) | ||
param_state_dict[key] = pre_state_dict[weight_name] | ||
else: | ||
print('fail Load weight: {}'.format(weight_name)) | ||
param_state_dict[key] = model_dict[key] | ||
video_model.set_dict(param_state_dict) | ||
|
||
#3. init optim | ||
optimizer = create_optimizer(train_config.TRAIN, | ||
video_model.parameters()) | ||
if use_data_parallel: | ||
#(data_parallel step2,3/6) | ||
strategy = fluid.dygraph.parallel.prepare_context() | ||
video_model = fluid.dygraph.parallel.DataParallel(video_model, | ||
strategy) | ||
|
||
# 4. load checkpoint | ||
if args.checkpoint: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. resume阶段，epoch计数是否对应调整下？ |
||
assert os.path.exists(args.checkpoint + ".pdparams"), \ | ||
"Given dir {}.pdparams not exist.".format(args.checkpoint) | ||
assert os.path.exists(args.checkpoint + ".pdopt"), \ | ||
"Given dir {}.pdopt not exist.".format(args.checkpoint) | ||
para_dict, opti_dict = fluid.dygraph.load_dygraph(args.checkpoint) | ||
video_model.set_dict(para_dict) | ||
optimizer.set_dict(opti_dict) | ||
|
||
# 5. reader | ||
bs_denominator = 1 | ||
if args.use_gpu: | ||
# check number of GPUs | ||
gpus = os.getenv("CUDA_VISIBLE_DEVICES", "") | ||
if gpus == "": | ||
pass | ||
|
@@ -168,27 +232,36 @@ def train(args): | |
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size / | ||
bs_denominator) | ||
|
||
train_reader = KineticsReader(mode="train", cfg=train_config) | ||
train_reader = UCF101Reader(name="TSM", mode="train", cfg=train_config) | ||
|
||
train_reader = train_reader.create_reader() | ||
if use_data_parallel: | ||
#(data_parallel step4/6) | ||
train_reader = fluid.contrib.reader.distributed_batch_reader( | ||
train_reader) | ||
|
||
# 6. train loop | ||
for epoch in range(train_config.TRAIN.epoch): | ||
video_model.train() | ||
total_loss = 0.0 | ||
total_acc1 = 0.0 | ||
total_acc5 = 0.0 | ||
total_sample = 0 | ||
t_last = time.time() | ||
# 6.1 for each batch, call model() , backward(), and minimize() | ||
for batch_id, data in enumerate(train_reader()): | ||
t1 = time.time() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. t1–t5重新命名一下？比如batch_start_time |
||
x_data = np.array([item[0] for item in data]) | ||
y_data = np.array([item[1] for item in data]).reshape([-1, 1]) | ||
|
||
imgs = to_variable(x_data) | ||
labels = to_variable(y_data) | ||
labels.stop_gradient = True | ||
|
||
t2 = time.time() | ||
outputs = video_model(imgs) | ||
t3 = time.time() | ||
|
||
loss = fluid.layers.cross_entropy( | ||
input=outputs, label=labels, ignore_index=-1) | ||
avg_loss = fluid.layers.mean(loss) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Copy avg_loss to a new variable and print that copy instead of avg_loss, to avoid printing avg_loss after the scale_loss call, at which point it has already been divided by the number of cards. |
||
|
@@ -198,34 +271,69 @@ def train(args): | |
acc_top5 = fluid.layers.accuracy( | ||
input=outputs, label=labels, k=5) | ||
|
||
current_step_lr = optimizer.current_step_lr() | ||
if use_data_parallel: | ||
#(data_parallel step5/6) | ||
avg_loss = video_model.scale_loss(avg_loss) | ||
avg_loss.backward() | ||
video_model.apply_collective_grads() | ||
else: | ||
avg_loss.backward() | ||
|
||
t4 = time.time() | ||
optimizer.minimize(avg_loss) | ||
video_model.clear_gradients() | ||
t5 = time.time() | ||
|
||
total_loss += avg_loss.numpy()[0] | ||
total_acc1 += acc_top1.numpy()[0] | ||
total_acc5 += acc_top5.numpy()[0] | ||
total_sample += 1 | ||
|
||
print('TRAIN Epoch {}, iter {}, loss = {}, acc1 {}, acc5 {}'. | ||
format(epoch, batch_id, | ||
avg_loss.numpy()[0], | ||
acc_top1.numpy()[0], acc_top5.numpy()[0])) | ||
print( | ||
'TRAIN Epoch: %d, iter: %d, loss: %.5f, acc1: %.5f, acc5: %.5f, lr: %.5f, forward_cost:%.5f s, backward_cost:%.5f s, minimize_cost:%.5f s, to_variable_cost: %.5f s, batch_cost: %.5f s, reader_cost: %.5f s' | ||
% (epoch, batch_id, avg_loss.numpy()[0], | ||
acc_top1.numpy()[0], acc_top5.numpy()[0], | ||
current_step_lr, t3 - t2, t4 - t3, t5 - t4, t2 - t1, | ||
t5 - t_last, t2 - t_last)) | ||
t_last = time.time() | ||
|
||
print( | ||
'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}'. | ||
'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}, lr={}'. | ||
format(epoch, total_loss / total_sample, total_acc1 / | ||
total_sample, total_acc5 / total_sample)) | ||
total_sample, total_acc5 / total_sample, | ||
current_step_lr)) | ||
|
||
# 6.2 save checkpoint | ||
save_parameters = (not use_data_parallel) or ( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 不用走or逻辑？单卡时也可以local_rank==0保存，参考 |
||
use_data_parallel and | ||
fluid.dygraph.parallel.Env().local_rank == 0 | ||
) #(data_parallel step6/6) | ||
if save_parameters: | ||
if not os.path.isdir(args.model_save_dir): | ||
os.makedirs(args.model_save_dir) | ||
model_path = os.path.join( | ||
args.model_save_dir, | ||
args.model_path_pre + "_epoch{}".format(epoch)) | ||
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path) | ||
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path) | ||
print('save_dygraph End, Epoch {}/{} '.format( | ||
epoch, train_config.TRAIN.epoch)) | ||
|
||
# 6.3 validation | ||
video_model.eval() | ||
val(epoch, video_model, valid_config, args) | ||
|
||
if fluid.dygraph.parallel.Env().local_rank == 0: | ||
fluid.dygraph.save_dygraph(video_model.state_dict(), "final") | ||
# 7. save final model | ||
save_parameters = (not args.use_data_parallel) or ( | ||
args.use_data_parallel and | ||
fluid.dygraph.parallel.Env().local_rank == 0) | ||
if save_parameters: | ||
model_path = os.path.join(args.model_save_dir, | ||
args.model_path_pre + "_final") | ||
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path) | ||
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path) | ||
|
||
logger.info('[TRAIN] training finished') | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the meaning of `model_path_pre`?