Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bf16 Support to Adam #134

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bmtrain/loss/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None:
stream = torch.cuda.current_stream().cuda_stream
C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)


def has_inf_nan_bf16(g_bf16: torch.Tensor, out: torch.Tensor) -> None:
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be a bfloat16 tensor"
assert out.dtype == torch.uint8, "out must be a uint8 tensor"
assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(out), "out must be contiguous and on cuda"
mid = torch.zeros(1024, device=out.device, dtype=out.dtype)
stream = torch.cuda.current_stream().cuda_stream
C.has_nan_inf_bf16_launcher(g_bf16.numel(), g_bf16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)

def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor,
softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None:
Expand Down
76 changes: 75 additions & 1 deletion bmtrain/optim/_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
Expand Down Expand Up @@ -76,3 +75,78 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso
bias_correction2,
stream
)

def adam_cpu_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert param_fp32.is_contiguous(), "param_fp32 must be contiguous"
assert param_bf16.is_contiguous(), "param_bf16 must be contiguous"
assert g_bf16.is_contiguous(), "g_bf16 must be contiguous"
assert m_fp32.is_contiguous(), "m_fp32 must be contiguous"
assert v_fp32.is_contiguous(), "v_fp32 must be contiguous"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_bf16.dtype == torch.bfloat16, "param_bf16 must be bfloat16 tensor"
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor"
assert param_bf16.device == torch.device("cpu"), "param_bf16 must be a cpu tensor"
assert g_bf16.device == torch.device("cpu"), "g_bf16 must be a cpu tensor"
assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor"
assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor"
assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements"
assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_bf16 must have the same number of elements"
assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
C.adam_cpu_bf16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_bf16.data_ptr(),
g_bf16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
)

def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_bf16.dtype == torch.bfloat16, "param_fp16 must be float16 tensor"
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements"
assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_fp16 must have the same number of elements"
assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_m_fp32 must have the same number of elements"
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
stream = torch.cuda.current_stream().cuda_stream
C.adam_bf16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_bf16.data_ptr(),
g_bf16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
stream
)
43 changes: 32 additions & 11 deletions bmtrain/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ def _on_justify_scale(self, old_scale, new_scale):
if p in self.state:
state = self.state[p]
if len(state) > 0:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta
#if p belongs to bf16, do not justify scale
if p.dtype == torch.bfloat16:
continue
else:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta

@torch.no_grad()
def step(self, closure=None, scale=1):
Expand All @@ -63,19 +67,22 @@ def step(self, closure=None, scale=1):
if p.grad is not None and p.requires_grad:
if p.grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('Adam only supports fp32 or fp16 gradients')
if p.dtype not in [torch.float32, torch.half, torch.bfloat16]:
raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients')

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device
if p.dtype == torch.bfloat16:
state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device
else:
state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device

if p.dtype == torch.half:
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)# on device
if p.dtype != torch.float32:
state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device
state['_param_fp32'].copy_(p)

Expand All @@ -88,7 +95,7 @@ def step(self, closure=None, scale=1):
grad = p.grad

if p.dtype == torch.half:
F.adam(
F.adam(
state["_param_fp32"], # fp32
p, # fp16
grad, # fp16
Expand All @@ -101,6 +108,20 @@ def step(self, closure=None, scale=1):
group['weight_decay'],
state['step']
)
elif p.dtype == torch.bfloat16:
F.adam_bf16(
state["_param_fp32"], # fp32
p, # bf16
grad, # bf16
state['exp_avg'], # fp32: m
state["exp_avg_sq"], # fp32: v
group['betas'][0], group['betas'][1],
group['eps'],
0.0 if state["step"] <= self._hold_steps else group['lr'],
scale,
group['weight_decay'],
state['step']
)
else:
other_kwargs = {}
if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters:
Expand Down Expand Up @@ -159,11 +180,11 @@ def load_state_dict(self, state_dict: dict) -> None:
if k in id_map:
param = id_map[k]

if param.dtype == torch.half and "_param_fp32" not in v:
if param.dtype != torch.float32 and "_param_fp32" not in v:
v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
v["_param_fp32"].copy_(param)

for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
for name, dtype in [("exp_avg", torch.float32 if param.dtype == torch.bfloat16 else param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
if name in v:
v[name] = v[name].to(param.device).to(dtype)

Expand Down
42 changes: 36 additions & 6 deletions bmtrain/optim/adam_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def step(self, closure=None, scale=1):
if p.grad is not None and p.requires_grad:
if p.grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('Adam only supports fp32 or fp16 gradients')
if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients')

state = self.state[p]
# Lazy state initialization
Expand All @@ -73,6 +73,15 @@ def step(self, closure=None, scale=1):
# placeholder
state["_param_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host
state["_grad_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host

elif p.dtype == torch.bfloat16:
state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host
state['_param_fp32'].copy_(p)

# placeholder
state["_param_bf16"] = torch.empty(p.size(), dtype=torch.bfloat16, pin_memory=True) # on host
state["_grad_bf16"] = torch.empty(p.size(), dtype=torch.bfloat16, pin_memory=True) # on host

else:
state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host
state['_param_fp32'].copy_(p)
Expand All @@ -89,6 +98,8 @@ def step(self, closure=None, scale=1):
for param, state, event, _, _, _, _, _ in update_params:
if param.dtype == torch.half:
state["_grad_fp16"].copy_(param.grad, non_blocking=True)
elif param.dtype == torch.bfloat16:
state ["_grad_bf16"].copy_(param.grad, non_blocking=True)
else:
state["_grad_fp32"].copy_(param.grad, non_blocking=True)
torch.cuda.current_stream().record_event(event)
Expand Down Expand Up @@ -119,6 +130,23 @@ def step(self, closure=None, scale=1):
)
# transfer parameters back to device asynchronously
param.copy_(state["_param_fp16"], non_blocking=True)
elif param.dtype == torch.bfloat16:
if ('maximize' in group) and (group['maximize'] is True):
grad = -state["_grad_bf16"]
else:
grad = state["_grad_bf16"]
F.adam_cpu_bf16(
state["_param_fp32"].view(-1),
state["_param_bf16"].view(-1),
grad.view(-1),
state["exp_avg"].view(-1),
state["exp_avg_sq"].view(-1),
beta1, beta2,
eps, 0.0 if state["step"] <= self._hold_steps else lr,
scale,
weight_decay,
state["step"]
)
else:
state["_grad_fp32"].mul_(1.0 / scale)
if ('maximize' in group) and (group['maximize'] is True):
Expand Down Expand Up @@ -197,9 +225,12 @@ def load_state_dict(self, state_dict: dict) -> None:
# initialize placeholders
state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host
state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host
elif param.dtype == torch.bfloat16:
#initialize placeholders
state[param]["_param_bf16"] = torch.empty(param.size(), dtype=torch.bfloat16, pin_memory=True) # on host
state[param]["_grad_bf16"] = torch.empty(param.size(), dtype=torch.bfloat16, pin_memory=True) # on host
else:
state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory()

state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() # on host
# initialize placeholders
state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host
else:
Expand Down Expand Up @@ -254,5 +285,4 @@ def cut_states(state):

#TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu
def zero_grad(self, set_to_none: bool = False):
super().zero_grad(set_to_none=set_to_none)

super().zero_grad(set_to_none=set_to_none)
7 changes: 4 additions & 3 deletions bmtrain/optim/optim_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Union, List, Dict, Tuple
import torch
from ..loss._function import has_inf_nan
from ..loss._function import has_inf_nan, has_inf_nan_bf16
from ..utils import print_rank
from ..lr_scheduler.warmup import WarmupLRScheduler
from .. import nccl
Expand All @@ -11,9 +11,10 @@ def check_overflow(param_groups):
has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0]
for group in param_groups:
for p in group['params']:
if p.grad is not None and p.dtype == torch.half: # TODO support other types
if p.grad is not None and p.dtype == torch.half:
has_inf_nan(p.grad, has_inf_or_nan)

elif p.grad is not None and p.dtype == torch.bfloat16:
has_inf_nan_bf16(p.grad, has_inf_or_nan) # TODO support other types
if "comm" in config:
nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"])

Expand Down
3 changes: 3 additions & 0 deletions csrc/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

PYBIND11_MODULE(C, m) {
m.def("has_nan_inf_launcher",&has_nan_inf_launcher,"has nan inf");
m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf bf16");
m.def("adam_launcher", &adam_launcher, "adam function cpu");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
m.def("adam_cpu_launcher", &adam_cpu_launcher, "adam function cpu");
m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
m.def("cross_entropy_forward_launcher", &cross_entropy_forward_launcher, "cross entropy forward");
m.def("cross_entropy_backward_launcher", &cross_entropy_backward_launcher, "cross entropy backward");
m.def("cross_entropy_forward_inplace_launcher", &cross_entropy_forward_inplace_launcher, "cross entropy forward inplace");
Expand Down
64 changes: 63 additions & 1 deletion csrc/cuda/adam_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>

namespace {
Expand Down Expand Up @@ -33,6 +34,40 @@ __global__ void adam_fp32_accum(
m[col] = __float2half(local_m);
}
}

__global__ void adam_fp32_accum_bf16(
int32_t n,
const nv_bfloat16 *g, // (n)
nv_bfloat16 *m, // (n)
float *v, // (n)
float* param, // (n)
nv_bfloat16* param_h, // (n)
float beta1,
float beta2,
float eps,
float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2
) {
int32_t col = blockIdx.x * blockDim.x + threadIdx.x;

if (col < n) {
float local_g = __bfloat162float(g[col]) / scale; // real_g
float local_m = beta1 * __bfloat162float(m[col]) + (1 - beta1) * local_g; // real_m
float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v
float local_p = param[col];
local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2 / scale) + eps) - lr * weight_decay * local_p;

param_h[col] = __float2bfloat16(local_p);
param[col] = local_p;
v[col] = local_v;
m[col] = __float2bfloat16(local_m);
}

}

}

void adam_launcher(
Expand Down Expand Up @@ -60,4 +95,31 @@ void adam_launcher(
dim3 block_size = dim3(threads, 1, 1);
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}
}

void adam_bf16_launcher(
int n,
std::uintptr_t param_fp32,
std::uintptr_t param_bf16,
std::uintptr_t g_bf16,
std::uintptr_t m_fp32,
std::uintptr_t v_fp32,
float beta1, float beta2,
float eps, float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2,
uintptr_t stream
) {
if (n <= 0) return;
auto g_ptr = reinterpret_cast<nv_bfloat16*>(g_bf16);
auto m_ptr = reinterpret_cast<nv_bfloat16*>(m_fp32);
auto param_h_ptr = reinterpret_cast<nv_bfloat16*>(param_bf16);
auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
int32_t threads = 1024;
dim3 block_size = dim3(threads, 1, 1);
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}
Loading