[inference] add int8 rotary embedding kernel for smoothquant (hpcaitech#4843)

* [shardformer] fix GPT2DoubleHeadsModel (hpcaitech#4703)

* [hotfix] Fix import error: colossal.kernel without triton installed (hpcaitech#4722)

* [hotfix] remove triton kernels from kernel init

* revise bloom/llama kernel imports for infer

* [shardformer] fix whisper test failures due to significant accuracy differences (hpcaitech#4710)

* [shardformer] fix whisper test failure

* [shardformer] fix whisper test failure

* [shardformer] fix whisper test failure

* [shardformer] fix whisper test failure

* [doc] fix llama2 code link (hpcaitech#4726)

* [doc] fix llama2 code link

* [doc] fix llama2 code link

* [doc] fix llama2 code link

* [doc] Add user document for Shardformer (hpcaitech#4702)

* create shardformer doc files

* add docstring for seq-parallel

* update ShardConfig docstring

* add links to llama example

* add outdated message

* finish introduction & supporting information

* finish 'how shardformer works'

* finish shardformer.md English doc

* fix doctest fail

* add Chinese document

* [format] applied code formatting on changed files in pull request 4726 (hpcaitech#4727)

Co-authored-by: github-actions <github-actions@github.com>

* [doc] add shardformer support matrix/update tensor parallel documents (hpcaitech#4728)

* add compatibility matrix for shardformer doc

* update tp doc

* Fixed some syntax errors in the documentation and code under applications/ (hpcaitech#4127)

Co-authored-by: flybird11111 <1829166702@qq.com>

* [shardformer] update pipeline parallel document (hpcaitech#4725)

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [shardformer] update pipeline parallel document

* [legacy] remove deterministic data loader test

* [shardformer] update seq parallel document (hpcaitech#4730)

* update doc of seq parallel

* fix typo

* [example] add gpt2 HybridParallelPlugin example (hpcaitech#4653)

* add gpt2 HybridParallelPlugin example

* update readme and testci

* update test ci

* fix test_ci bug

* update requirements

* add requirements

* update requirements

* add requirement

* rename file

* [doc] polish shardformer doc (hpcaitech#4735)

* arrange position of chapters

* fix typos in seq parallel doc

* [shardformer] add custom policy in hybrid parallel plugin (hpcaitech#4718)

* add custom policy

* update assert

* [example] llama2 add fine-tune example (hpcaitech#4673)

* [shardformer] update shardformer readme

[shardformer] update shardformer readme

[shardformer] update shardformer readme

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] change dataset

* [shardformer] change dataset

* [shardformer] fix CI

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

[example] update opt example

[example] resolve comments

fix

fix

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* fix

* update llama2 example

* update llama2 example

* fix

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* Update requirements.txt

* update llama2 example

* update llama2 example

* update llama2 example

* [doc] explanation of loading large pretrained models (hpcaitech#4741)

* [kernel] update triton init hpcaitech#4740 (hpcaitech#4740)

* [legacy] clean up legacy code (hpcaitech#4743)

* [legacy] remove outdated codes of pipeline (hpcaitech#4692)

* [legacy] remove cli of benchmark and update optim (hpcaitech#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (hpcaitech#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (hpcaitech#4696)

* [legacy] clean up utils (hpcaitech#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (hpcaitech#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci

* [format] applied code formatting on changed files in pull request 4743 (hpcaitech#4750)

Co-authored-by: github-actions <github-actions@github.com>

* [misc] update pre-commit and run all files (hpcaitech#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format

* [doc] explain suitable use case for each plugin

* [doc] put individual plugin explanation in front

* [doc] add model examples for each plugin

* [doc] put native colossalai plugins first in description section

* [chat]: update rm, add wandb and fix bugs (hpcaitech#4471)

* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>

* [shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (hpcaitech#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs

* [bug] fix get_default_parser in examples (hpcaitech#4764)

* [doc] clean up outdated docs (hpcaitech#4765)

* [doc] clean up outdated docs

* [doc] fix linking

* [doc] fix linking

* [doc] add shardformer doc to sidebar (hpcaitech#4768)

* [chat]: add lora merge weights config (hpcaitech#4766)

* feat: modify lora merge weights fn

* feat: add lora merge weights config

* [lazy] support torch 2.0 (hpcaitech#4763)

* [lazy] support _like methods and clamp

* [lazy] pass transformers models

* [lazy] fix device move and requires grad

* [lazy] fix requires grad and refactor api

* [lazy] fix requires grad

* [bug] Fix the version check bug in colossalai run when generating the cmd. (hpcaitech#4713)

* Fix the version check bug in colossalai run when generating the cmd.

* polish code

* [feature] add gptq for inference (hpcaitech#4754)

* [gptq] add gptq kernel (hpcaitech#4416)

* add gptq

* refactor code

* fix tests

* replace auto-gptq

* rename inference/quant

* refactor test

* add auto-gptq as an option

* reset requirements

* change assert and check auto-gptq

* add import warnings

* change test flash attn version

* remove example

* change requirements of flash_attn

* modify tests

* [skip ci] change requirements-test

* [gptq] faster gptq cuda kernel (hpcaitech#4494)

* [skip ci] add cuda kernels

* add license

* [skip ci] fix max_input_len

* format files & change test size

* [skip ci]

* [gptq] add gptq tensor parallel (hpcaitech#4538)

* add gptq tensor parallel

* add gptq tp

* delete print

* add test gptq check

* add test auto gptq check

* [gptq] combine gptq and kv cache manager (hpcaitech#4706)

* combine gptq and kv cache manager

* add init bits

* delete useless code

* add model path

* delete useless print and update test

* delete useless import

* move option gptq to shard config

* change replace linear to shardformer

* update bloom policy

* delete useless code

* fix import bug and delete useless code

* change colossalai/gptq to colossalai/quant/gptq

* update import linear for tests

* delete useless code and mv gptq_kernel to kernel directory

* fix triton kernel

* add triton import

* [inference] chatglm2 infer demo (hpcaitech#4724)

* add chatglm2

* add

* gather needed kernels

* fix some bugs

* finish context forward

* finish context stage

* fix

* add

* pause

* add

* fix bugs

* finish chatglm

* fix bug

* change some logic

* fix bugs

* change some logics

* add

* add

* add

* fix

* fix tests

* fix

* [release] update version (hpcaitech#4775)

* [release] update version

* [doc] revert versions

* initial commit: add colossal llama 2 (hpcaitech#4784)

* [feature] ColossalEval: Evaluation Pipeline for LLMs (hpcaitech#4786)

* Add ColossalEval

* Delete evaluate in Chat

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [doc] add llama2 domain-specific solution news (hpcaitech#4789)

* [doc] add llama2 domain-specific solution news

* [fix] fix weekly running example (hpcaitech#4787)

* [fix] fix weekly running example

* [fix] fix weekly running example

* [doc] polish shardformer doc (hpcaitech#4779)

* fix example format in docstring

* polish shardformer doc

* [checkpointio] support unsharded checkpointIO for hybrid parallel (hpcaitech#4774)

* support unsharded saving/loading for model

* support optimizer unsharded saving

* update doc

* support unsharded loading for optimizer

* small fix

* update readme

* [lazy] support from_pretrained (hpcaitech#4801)

* [lazy] patch from pretrained

* [lazy] fix from pretrained and add tests

* [devops] update ci

* update

* [hotfix] change llama2 Colossal-LLaMA-2 script filename (hpcaitech#4800)

change filename:
pretraining.py -> trainin.py
(there is no file named pretraining.py; the original name was written incorrectly)

* [misc] add last_epoch in CosineAnnealingWarmupLR (hpcaitech#4778)

* [doc] add lazy init docs (hpcaitech#4808)

* [hotfix] fix norm type error in zero optimizer (hpcaitech#4795)

* [hotfix] Correct several erroneous code comments (hpcaitech#4794)

* [format] applied code formatting on changed files in pull request 4595 (hpcaitech#4602)

Co-authored-by: github-actions <github-actions@github.com>

* fix format (hpcaitech#4815)

* [chat] fix gemini strategy (hpcaitech#4698)

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* This is a combination of 2 commits.

[chat] fix gemini strategy

fix

* [chat] fix gemini strategy

update llama2 example

[chat] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* fix

* fix

* fix

* fix

* fix

* Update train_prompts.py

* Update Qwen-7B results (hpcaitech#4821)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc] update slack link (hpcaitech#4823)

* add autotune (hpcaitech#4822)

* update Colossal (hpcaitech#4832)

* add int8 rotary embedding kernel

* remove useless code

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
Co-authored-by: Chandler-Bing <brp12138@163.com>
Co-authored-by: Yan haixu <40758050+hova88@users.noreply.github.com>
22 people committed Oct 13, 2023
1 parent 39f2582 commit 37eb9aa
Showing 3 changed files with 180 additions and 0 deletions.
2 changes: 2 additions & 0 deletions colossalai/kernel/triton/__init__.py
@@ -13,6 +13,7 @@
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
@@ -28,4 +29,5 @@
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",
]
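With the export in place, downstream code can pull in the new kernel next to the existing rotary kernels. A minimal usage sketch follows; the guarded import mirrors the pattern used in the test below, and the fallback names (HAS_INT8_ROTARY) are illustrative assumptions, not part of this commit:

# Sketch: import the new int8 rotary kernel, degrading gracefully when triton is missing.
try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_INT8_ROTARY = True
except ImportError:
    int8_rotary_embedding_fwd = None  # illustrative fallback; install triton to enable the kernel
    HAS_INT8_ROTARY = False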
119 changes: 119 additions & 0 deletions colossalai/kernel/triton/int8_rotary_embedding_kernel.py
@@ -0,0 +1,119 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
    q,
    input_scale,
    output_scale,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    current_head_index = tl.program_id(0)
    current_seq_index = tl.program_id(1)

    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

    off_q0 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range0[None, None, :] * q_d_stride
    )
    off_q1 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range1[None, None, :] * q_d_stride
    )

    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

    q0 = tl.load(
        q + off_q0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )
    q1 = tl.load(
        q + off_q1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )

    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    in_scale = tl.load(input_scale)
    o_scale = tl.load(output_scale)

    q0 = q0.to(tl.float32) * in_scale
    q1 = q1.to(tl.float32) * in_scale

    out0 = (q0 * cos - q1 * sin) / o_scale
    out1 = (q0 * sin + q1 * cos) / o_scale

    # out0 = out0.to(tl.int8)
    # out1 = out1.to(tl.int8)

    tl.store(
        q + off_q0,
        out0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )
    tl.store(
        q + off_q1,
        out1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )

    return


@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    _rotary_kernel[grid](
        q,
        input_scale,
        output_scale,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
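For readers skimming the diff, here is a minimal PyTorch sketch of what the kernel computes per position: dequantize the int8 query with input_scale, apply the usual rotary rotation over the two half-dimensions, and re-quantize with output_scale before writing back in place. The standalone function name and the explicit round() are assumptions for illustration only; the kernel itself relies on tl.store's implicit conversion into the int8 buffer.

import torch


def int8_rotary_reference(q_int8, cos, sin, input_scale, output_scale):
    # q_int8: (total_len, head_num, head_dim) int8 query, pre-scaled by 1 / input_scale
    # cos, sin: (total_len, head_dim // 2) float tensors
    q = q_int8.to(torch.float32) * input_scale  # dequantize
    d = q.shape[-1]
    q0, q1 = q[..., : d // 2], q[..., d // 2 :]
    cos = cos.view(cos.shape[0], 1, d // 2)
    sin = sin.view(sin.shape[0], 1, d // 2)
    out0 = (q0 * cos - q1 * sin) / output_scale  # rotate, then requantize
    out1 = (q0 * sin + q1 * cos) / output_scale
    # rounding here approximates the implicit int8 store performed by the Triton kernel
    return torch.cat((out0, out1), dim=-1).round().to(torch.int8)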
59 changes: 59 additions & 0 deletions tests/test_smoothquant/test_rotary_embedding.py
@@ -0,0 +1,59 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm


import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.float
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # forward pass
    y_torch = torch_rotary_emb(x, cos, sin)

    input_scale = torch.max(torch.abs(x)) / 127
    output_scale = torch.max(torch.abs(y_torch)) / 127

    x = x / input_scale
    x = x.to(torch.int8)

    int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale)
    y_triton = x.to(torch.float) * output_scale
    assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)


if __name__ == "__main__":
    test_rotary_emb()
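A note on the tolerances: both the input and the output pass through int8, so each element carries quantization error on the order of its scale (roughly max|x| / 127 on the way in and max|y_torch| / 127 on the way out), which is why the comparison uses atol=2e-1 and rtol=1e-2 rather than exact equality.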
