Commit
using set_optimizer_grouped_parameters instead
Yam0214 committed Feb 13, 2023
1 parent bcbf2ee commit 4a0d661
Showing 1 changed file with 32 additions and 45 deletions.
77 changes: 32 additions & 45 deletions model_zoo/ernie-m/run_classifier.py
@@ -21,15 +21,12 @@
 
 import numpy as np
 import paddle
-import paddle.nn as nn
 from datasets import load_dataset
 from paddle.io import Dataset
 from paddle.metric import Accuracy
-from paddle.optimizer import AdamW
 
 import paddlenlp
 from paddlenlp.data import DataCollatorWithPadding
-from paddlenlp.ops.optimizer import layerwise_lr_decay
 from paddlenlp.trainer import (
     PdArgumentParser,
     Trainer,
@@ -40,7 +37,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
     ErnieMForSequenceClassification,
-    LinearDecayWithWarmup,
 )
 from paddlenlp.utils.log import logger
 
@@ -150,8 +146,6 @@ def do_train():
     training_args.print_config(model_args, "Model")
 
     paddle.set_device(training_args.device)
-    if paddle.distributed.get_world_size() > 1:
-        paddle.distributed.init_parallel_env()
 
     set_seed(training_args.seed)
 
@@ -208,43 +202,6 @@ def collect_all_languages_dataset(split):
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path, num_labels=num_labels, classifier_dropout=model_args.classifier_dropout
     )
-    n_layers = model.config.num_hidden_layers
-    if paddle.distributed.get_world_size() > 1:
-        model = paddle.DataParallel(model)
-
-    warmup = training_args.warmup_steps if training_args.warmup_steps > 0 else training_args.warmup_ratio
-    if training_args.do_train:
-        num_training_steps = (
-            training_args.max_steps
-            if training_args.max_steps > 0
-            else len(train_ds) // training_args.train_batch_size * training_args.num_train_epochs
-        )
-    else:
-        num_training_steps = 10
-    lr_scheduler = LinearDecayWithWarmup(training_args.learning_rate, num_training_steps, warmup)
-
-    # Generate parameter names needed to perform weight decay.
-    # All bias and LayerNorm parameters are excluded.
-    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
-    # Construct dict
-    name_dict = dict()
-    for n, p in model.named_parameters():
-        name_dict[p.name] = n
-
-    simple_lr_setting = partial(layerwise_lr_decay, model_args.layerwise_decay, name_dict, n_layers)
-
-    optimizer = AdamW(
-        learning_rate=lr_scheduler,
-        beta1=0.9,
-        beta2=0.999,
-        epsilon=training_args.adam_epsilon,
-        parameters=model.parameters(),
-        weight_decay=training_args.weight_decay,
-        apply_decay_param_fun=lambda x: x in decay_params,
-        lr_ratio=simple_lr_setting,
-    )
-
-    criterion = nn.CrossEntropyLoss()
 
     # Define the metrics of tasks.
     def compute_metrics(p):
@@ -262,16 +219,46 @@ def compute_metrics(p):
 
     trainer = Trainer(
         model=model,
-        criterion=criterion,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_ds if training_args.do_train else None,
         eval_dataset=eval_ds if training_args.do_eval else None,
         tokenizer=tokenizer,
         compute_metrics=compute_metrics,
-        optimizers=[optimizer, lr_scheduler],
     )
 
+    def using_layerwise_lr_decay(layerwise_decay, model, training_args):
+        """
+        Generate parameter names needed to perform weight decay.
+        All bias and LayerNorm parameters are excluded.
+        """
+        # params_list = [{"params": param, "learning_rate": lr * decay_ratio}, ... ]
+        params_list = []
+        ratio = 1.0
+        n_layers = model.config.num_hidden_layers
+        for name, param in model.named_parameters():
+            param_to_train = {"params": param, "dygraph_key_name": name}
+            if any(nd in name for nd in ["bias", "norm"]):
+                param_to_train["weight_decay"] = 0.0
+            else:
+                param_to_train["weight_decay"] = training_args.weight_decay
+
+            if "encoder.layers" in name:
+                idx = name.find("encoder.layers.")
+                layer = int(name[idx:].split(".")[2])
+                ratio = layerwise_decay ** (n_layers - layer)
+            elif "embedding" in name:
+                ratio = layerwise_decay ** (n_layers + 1)
+
+            param_to_train["learning_rate"] = ratio
+
+            params_list.append(param_to_train)
+        return params_list
+
+    params_to_train = using_layerwise_lr_decay(model_args.layerwise_decay, model, training_args)
+
+    trainer.set_optimizer_grouped_parameters(params_to_train)
+
     checkpoint = None
     if training_args.resume_from_checkpoint is not None:
         checkpoint = training_args.resume_from_checkpoint

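For reference, below is a minimal standalone sketch (not part of the commit) of the learning-rate scaling rule that the new using_layerwise_lr_decay helper applies: parameters of encoder layer k are scaled by layerwise_decay ** (n_layers - k), and embedding parameters by layerwise_decay ** (n_layers + 1). The layer count, decay factor, and parameter names used here are illustrative assumptions, not values taken from the commit.

def layerwise_lr_ratio(name, n_layers, layerwise_decay):
    """Learning-rate multiplier a parameter would receive under layer-wise decay."""
    if "encoder.layers" in name:
        # Encoder layer k is scaled by decay ** (n_layers - k), so later
        # (higher) layers keep a learning rate closer to the base rate.
        idx = name.find("encoder.layers.")
        layer = int(name[idx:].split(".")[2])
        return layerwise_decay ** (n_layers - layer)
    if "embedding" in name:
        # Embeddings sit below the first encoder layer and get the strongest decay.
        return layerwise_decay ** (n_layers + 1)
    # Other parameters (e.g. a classifier head) keep the base rate in this
    # sketch; the helper in the diff instead carries the last computed ratio forward.
    return 1.0


if __name__ == "__main__":
    n_layers, decay = 12, 0.8  # assumed values, for illustration only
    for name in [
        "embeddings.word_embeddings.weight",  # hypothetical parameter names
        "encoder.layers.0.self_attn.q_proj.weight",
        "encoder.layers.11.linear2.weight",
        "classifier.weight",
    ]:
        print(f"{name:45} ratio = {layerwise_lr_ratio(name, n_layers, decay):.4f}")

Decaying the learning rate more strongly for lower layers keeps the pretrained low-level representations close to their initial values while letting the task-specific top layers adapt faster.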