pytext fp16 optimizer (facebookresearch#782)
Summary:
Pull Request resolved: facebookresearch#782

Write an optimizer wrapper in PyText that supports mixed precision training without Apex amp.

Differential Revision: D16276949

fbshipit-source-id: 64c8f1367f277bc133ab4bb52540ff964d44e971
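
For context, a hedged sketch of how the wrapper is meant to be driven from a training loop, based on the helpers added in pytext/utils/precision.py below. The names model, optimizer, batches, compute_loss and max_clip_norm are illustrative placeholders, not code from this diff, and the sketch assumes fp16 is enabled and the non-amp path is active:

    # Illustrative only; not part of this commit.
    from pytext.utils import precision

    model, optimizer = precision.initialize(model, optimizer)  # model.half(), FP16Optimizer(optimizer)
    for batch in batches:
        optimizer.zero_grad()
        loss = compute_loss(model, batch)                 # forward pass runs in fp16
        precision.backward(optimizer, loss)               # scales the loss before backward
        precision.clip_grad_norm(model, optimizer, max_clip_norm)  # unscale + clip fp32 master grads
        optimizer.step()                                  # fp32 update, weights copied back to fp16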
Yuqing Liu authored and facebook-github-bot committed Jul 23, 2019
1 parent ea233ac commit 4ee8177
Showing 2 changed files with 791 additions and 6 deletions.
350 changes: 344 additions & 6 deletions pytext/utils/precision.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from contextlib import contextmanager
from sys import stderr

@@ -85,6 +86,7 @@ def finalize(self) -> bool:
_FP16_ENABLED = False
_OPT_LEVEL = None
_DELAY_UNSCALE = False
_NON_AMP = True


@contextmanager
@@ -118,7 +120,10 @@ def initialize(model, optimizer):
global _OPT_LEVEL

if _FP16_ENABLED:

_OPT_LEVEL = "O2" if model.SUPPORT_FP16_OPTIMIZER else "O1"
if _NON_AMP:
return model.half(), FP16Optimizer(optimizer)
return amp.initialize(model, optimizer, opt_level=_OPT_LEVEL)
else:
return model, optimizer
@@ -128,20 +133,29 @@ def backward(optimizer, loss):
if _FP16_ENABLED:
# 1. Use automatic loss scaling to best use fp16 range
# 2. Clear handle's cache of casted parameters
if loss > 0:
with amp.scale_loss(
loss, optimizer, delay_unscale=_DELAY_UNSCALE
) as scaled_loss:
scaled_loss.backward()
if _NON_AMP:
if loss > 0:
optimizer.backward(loss)
else:
loss.backward()
else:
loss.backward()
if loss > 0:
with amp.scale_loss(
loss, optimizer, delay_unscale=_DELAY_UNSCALE
) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
else:
loss.backward()


def clip_grad_norm(model, optimizer, max_clip_norm):
if _FP16_ENABLED:
# Refer: https://nvidia.github.io/apex/advanced.html
if _NON_AMP:
return optimizer.clip_grad_norm(max_clip_norm)

return torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), max_clip_norm
)
@@ -189,3 +203,327 @@ def pad_length(n):
n = n + 8 - remainder

return n


"""fp16 optimizer wraps torch.optim to support mixed precision training
structure of fp16Optimzier:
property
fp16_optimizer.param_groups ----------> inner_optimizer.param_groups
| |
___ __ |__ __ __ __ __ __ | __ __ __
| fp16 | after backward | fp32 |
zero_grad ----|-> grads --|-----------------|--> grads <--|-- check overflow
loss --->| weights <-|-----------------|-- weights |
model --->|_ __ __ __ __ __| after step |__ __ __ __ __ __ |
usage:
1 optim.zero_grad()
2 for i in range(N):
3 model.forward() ---- fp16 weights
4 [pre_process()] ---- fp16 grads upscale
5 optim.backward() ---- upscaled fp16 grads
6 [post_process()] ---- downscale and float to fp32 grads
7 optim.step() ---- fp32 weights and grads
class FP16_Optimizer:
= Properties:
- inner_optimizer:
= type: Torch.optim
= contents: optimizer in pytext (eg. Adam)
which is initialized with fp16 params already
- param_groups:
= type: list of dictionaries where key is string and value is a list.
= contents: eg. [{'params':[]}]
- temp_fp32_params
= types: same as param_groups
= purpose: to support accumulating grads calculation
= contents: contain the temp fp32 grads from backward()
and will be unscaled and added to inner optimizer
- scaler:
- flags: BOOLEAN
= weights_update_needed: whether need to copy weights from master to model
= grads_update_needed: whether need to copy grads from model to master
= Methods:
- __init__()
- zero_grad
= effects: clear the grads in self.param_groups(fp16)
- backward()
- post_process()
- step(loss)
class DynamicLossScaler:
= properties:
- init_scale: the beginning scale number
- scale_factor: the step length that we use to increase the scale
- scale_window: the upper bound of iterations among which no overflow is triggered
- tolerance: the upper bound of the frequency that overflow happens
- threshold: the minimum value of the scale
- is_overflow
- is_scaled: whether grads are scaled
= Methods:
- check_overflow
- upscale
- unscale
- update_scale
"""


class DynamicLossScaler(object):
def __init__(self, init_scale, scale_factor, scale_window, tolerance, threshold):
self.scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self.tolerance = tolerance
self.threshold = threshold
self._iter = 0
self._last_overflow_iter = 0
self._last_rescale_iter = 0
self._overflows_since_rescale = 0
self.is_overflow = False

def upscale(self, loss):
return loss.float() * self.scale

def unscale(self, grad):
return grad.mul_(1.0 / self.scale)

def check_overflow(self, grad_norm):
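# inf is caught by the explicit comparisons below; NaN is caught by
# grad_norm != grad_norm, since NaN never compares equal to itself.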
if (
grad_norm == float("inf")
or grad_norm == -float("inf")
or grad_norm != grad_norm
):
self.is_overflow = True
else:
self.is_overflow = False

def check_overflow_step(self, model_params):
for group in model_params:
for p in group["params"]:
if p.grad is not None and self._check_overflow_step(p.grad):
return True
return False

def _check_overflow_step(self, grad):
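# The gradient sum is pulled to a Python float: it is non-finite if any
# gradient element is inf or NaN, which the checks below detect.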
cpu_sum = float(grad.float().sum())
if cpu_sum == float("inf") or cpu_sum == -float("inf") or cpu_sum != cpu_sum:
return True
return False

def update_scale(self):
"""
= effects:
- if last overflow is far from now, it's time to increase scale
- if more overflow happens than we expected, it's time to decrease the scale
"""
self._iter += 1
if self.is_overflow:
self._last_overflow_iter = self._iter
self.scale = max(self.scale / self.scale_factor, 1)
print(
    "Overflow detected, skipping step; new loss scale is {}".format(self.scale)
)
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.scale *= self.scale_factor
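
# Worked example of the schedule above (hedged, numbers are illustrative only):
# with init_scale=2.0 ** 15, scale_factor=2 and scale_window=2000, the scale stays
# at 32768 while no overflow occurs, doubles after 2000 consecutive clean
# iterations, and is halved (never below 1) whenever an overflow is seen:
#
#     scaler = DynamicLossScaler(2.0 ** 15, 2, 2000, 0.05, None)
#     scaler.check_overflow(float("inf"))   # is_overflow = True
#     scaler.update_scale()                 # scale: 32768.0 -> 16384.0
#     scaler.check_overflow(1.7)            # finite norm, is_overflow = False
#     scaler.update_scale()                 # unchanged until 2000 clean iterations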


class FP16Optimizer(object):
def __init__(
self,
init_optimizer,
init_scale=2.0 ** 15,
scale_factor=2,
scale_window=2000,
tolerance=0.05,
threshold=None,
):
"""
= input: init_optimizer(initialized already), init_scale, scale_factor,
scale_window, tolerance, threshold
= effects: initialize the optimizer and create master and loss scaling tools
= modifies:
- record the reference of model params (fp16)
- change the inner optimizer's params to fp32 with
torch.optim inner method
- initialized the scaler
- initialized state, default
"""
self.inner_optimizer = init_optimizer
self.param_groups = []
for _i, group in enumerate(self.inner_optimizer.param_groups):
fp16_group = {}
for key, value in group.items():
if key == "params":
fp16_param = []
for j, p in enumerate(value):
fp16_param.append(p)
master_p = p.detach().clone().float()
master_p.requires_grad_(True)
group["params"][j] = master_p
# change the state map:
if p in self.inner_optimizer.state:
self.inner_optimizer.state[
master_p
] = self.inner_optimizer.state.pop(p)

fp16_group["params"] = fp16_param
else:
fp16_group[key] = value
self.param_groups.append(fp16_group)
self.loss_scaler = DynamicLossScaler(
init_scale, scale_factor, scale_window, tolerance, threshold
)
self.state = self.inner_optimizer.state
self.weights_update_needed = False
self.grads_update_needed = False
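
# After construction (a hedged illustration of the pairing set up above):
#   self.param_groups[i]["params"][j]                  -> original fp16 model tensor
#   self.inner_optimizer.param_groups[i]["params"][j]  -> its detached fp32 master copy
# The shared (i, j) index is the only link between a model param and its master,
# which is why the copy helpers below walk both structures in lockstep.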

def zero_grad(self):
for group in self.param_groups:
for p in group["params"]:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()

def backward(self, loss):
"""
= input: loss
= effects: do loss scaling and calculate grads
= modifies:
- upscale grads
- call loss.backward()
"""
scaled_loss = self.loss_scaler.upscale(loss)
scaled_loss.backward()
self.grads_update_needed = True

def _clip_grad_norm(self, parameters, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if norm_type == float("inf"):
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

def clip_grad_norm_(self):
return (
self._clip_grad_norm(self.model_params()),
self._clip_grad_norm(self.master_params()),
)

def clip_grad_norm(self, max_norm):
self._grads_from_model_to_master()
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(), max_norm)
self.loss_scaler.check_overflow(grad_norm)
self.loss_scaler.update_scale()
if not self.loss_scaler.is_overflow:
return grad_norm
else:
print("for overflow debug: overflow happens")

def step(self):
"""
= input: closure
= effects:
- check model grads whether are overflow
- update the grads from model to master
- call inner optimizer's step
- copy back the weights from inner optimizer to model
"""
self.loss_scaler.is_overflow = self.loss_scaler.check_overflow_step(
self.param_groups
)
if not self.loss_scaler.is_overflow:
self._grads_from_model_to_master()
self.inner_optimizer.step()
self.weights_update_needed = True
self._weights_from_master_to_model()

self.loss_scaler.update_scale()
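
# Illustration of the overflow-skip behaviour above (hedged sketch, not part of
# this commit): if any fp16 grad contains inf/NaN, step() leaves the fp32 master
# weights untouched and only lowers the loss scale, so a caller can simply do
#
#     fp16_opt.backward(loss)   # may produce inf/NaN grads if the scale is too large
#     fp16_opt.step()           # detects the overflow, skips the update, halves the scale
#
# and continue training with the next batch.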

def _grads_from_model_to_master(self):
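# Copy the (scaled) fp16 grads onto the fp32 master params and unscale them
# in place, so the inner optimizer always steps on unscaled fp32 grads.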
if self.grads_update_needed:
for i, group in enumerate(self.param_groups):
for j, p in enumerate(group["params"]):
if self.inner_optimizer.param_groups[i]["params"][j].grad is None:
self.inner_optimizer.param_groups[i]["params"][
j
].grad = torch.empty_like(
self.inner_optimizer.param_groups[i]["params"][j]
)
self.inner_optimizer.param_groups[i]["params"][j].grad.copy_(p.grad)
self.loss_scaler.unscale(
self.inner_optimizer.param_groups[i]["params"][j].grad
)
self.grads_update_needed = False

def _weights_from_master_to_model(self):
if self.weights_update_needed:
for i, group in enumerate(self.inner_optimizer.param_groups):
for j, p in enumerate(group["params"]):
self.param_groups[i]["params"][j].data.copy_(p.data)
self.weights_update_needed = False

def state_dict(self):
state_dict = {}
state_dict["loss_scaler"] = self.loss_scaler
state_dict["loss_scale"] = self.loss_scaler.scale
state_dict["overflow"] = self.loss_scaler.overflow
state_dict["param_groups"] = self.param_groups
state_dict["optimizer_state_dict"] = self.inner_optimizer.state_dict()
return state_dict

def load_state_dict(self, state_dict):
self.loss_scaler = state_dict["loss_scaler"]
self.loss_scaler.scale = state_dict["loss_scale"]
self.loss_scaler.is_overflow = state_dict["overflow"]
self.inner_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
for i, group in enumerate(state_dict["param_groups"]):
for j, p in enumerate(group["params"]):
self.param_groups[i]["params"][j].data.copy_(p.data)
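
# Hedged checkpoint round-trip sketch for the two methods above (path is a
# placeholder filename, not something defined in this commit):
#
#     torch.save(fp16_opt.state_dict(), path)
#     ...
#     fp16_opt.load_state_dict(torch.load(path))
#
# Note that the DynamicLossScaler object itself is stored inside the state dict,
# so that class must be importable when the checkpoint is loaded.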

def master_params(self):
for group in self.inner_optimizer.param_groups:
for p in group["params"]:
yield p

def model_params(self):
for group in self.param_groups:
for p in group["params"]:
yield p

def finalize(self):
return self.inner_optimizer.finalize()

def __getstate__(self):
raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")

def __setstate__(self, state):
raise RuntimeError(
"FP16_Optimizer should be deserialized using load_state_dict()."
)

def _get_loss_scale(self):
return self.loss_scaler.scale

def _set_loss_scale(self, value):
self.loss_scaler.scale = value

def _get_state(self):
return self.state

def _set_state(self, value):
self.state = value

def _get_param_groups(self):
return self.param_groups

def _set_param_groups(self, value):
self.param_groups = value