Forked from hpcaitech/ColossalAI

Commit 4bd3ebc

[booster] init module structure and definition (hpcaitech#3056)

1 parent: faa8526
Showing 11 changed files with 600 additions and 0 deletions.

@@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import json
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer import Embedding1D, Linear1D_Col, Linear1D_Row

from .llama_tokenizer import LLaMATokenizer


@dataclass
class LLaMAConfig:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1    # defined later by tokenizer
    multiple_of: int = 256    # make SwiGLU hidden layer size a multiple of a large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024

class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scales by 1/RMS(x) along the feature
    dimension, with a learned gain and no mean subtraction or bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), taken over the last (feature) dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for numerical stability, then cast back
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

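# Rotary position embeddings (RoPE): each channel pair (2i, 2i+1) at position t
# is rotated by the angle t * theta^(-2i/dim). The rotations are precomputed as
# unit complex numbers e^{i * t * freq}, so applying them later reduces to one
# complex multiply per pair.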
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)    # type: ignore
    freqs = torch.outer(t, freqs).float()    # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)    # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # reshape (seqlen, head_dim // 2) -> (1, seqlen, 1, head_dim // 2) so the
    # frequencies broadcast over the batch and head dimensions of x
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # view the last dimension as complex pairs, rotate by the unit-modulus
    # frequencies, then flatten back to real-valued tensors
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

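# Attention is sharded across the 1D tensor-parallel group along the head
# dimension: wq/wk/wv are column-parallel (each rank keeps
# n_heads / world_size heads, gather_output=False), and wo is row-parallel,
# whose all-reduce restores the full hidden state.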
class Attention(nn.Module):

    def __init__(self, args: LLaMAConfig):
        super().__init__()

        self.n_local_heads = args.n_heads // gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.head_dim = args.dim // args.n_heads

        self.wq = Linear1D_Col(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = Linear1D_Col(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = Linear1D_Col(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = Linear1D_Row(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # per-rank KV cache for incremental decoding; holds only the local heads
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # write the new keys/values into the cache, then attend over everything
        # up to and including the current position
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, :start_pos + seqlen]
        values = self.cache_v[:bsz, :start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask    # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)    # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

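# SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)). The hidden size is shrunk to
# 2/3 of the nominal 4 * dim (so the three-matrix SwiGLU roughly matches the
# parameter count of a two-matrix FFN), then rounded up to a multiple of
# `multiple_of`. E.g. dim=512: 4 * 512 = 2048 -> 2/3 -> 1365 -> rounds to 1536.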
class FeedForward(nn.Module):

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = Linear1D_Col(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
        self.w2 = Linear1D_Row(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
        self.w3 = Linear1D_Col(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):

    def __init__(self, layer_id: int, args: LLaMAConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        # pre-norm residual layout: x + Attn(norm(x)), then h + FFN(norm(h))
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class LLaMA(nn.Module):

    def __init__(self, config: LLaMAConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers

        self.tok_embeddings = Embedding1D(config.vocab_size, config.dim, init_method=lambda x: x)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layers):
            self.layers.append(TransformerBlock(layer_id, config))

        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = Linear1D_Col(config.dim, config.vocab_size, bias=False, init_method=lambda x: x)

        self.freqs_cis = precompute_freqs_cis(self.config.dim // self.config.n_heads, self.config.max_seq_len * 2)

    # @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        # causal mask over the new tokens; in this generator seqlen > 1 only
        # occurs on the first call, where start_pos == 0
        mask = None
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])    # only compute last logits
        return output.float()

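# Incremental decoding driver: the first forward pass consumes the shortest
# prompt prefix, then tokens are generated one position at a time against the
# KV cache. Positions still covered by a longer prompt are re-asserted from
# the prompt via input_text_mask rather than sampled.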
class LLaMAGenerator:

    def __init__(self, model: LLaMA, tokenizer: LLaMATokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        config = self.model.config
        assert bsz <= config.max_batch_size, (bsz, config.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(config.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[:len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[:t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded

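# Nucleus (top-p) sampling: sort probabilities in descending order and zero
# out every token whose cumulative mass *before* it already exceeds p, so only
# the smallest prefix covering probability p survives; renormalize and sample
# from that prefix.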
def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

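# Checkpoint loading: expects one *.pth shard per tensor-parallel rank in
# ckpt_dir; rank i loads shard i. The default tensor type is temporarily set
# to cuda.HalfTensor so the model's parameters are allocated in fp16 on GPU.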
def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMAGenerator:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(checkpoints), \
        f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: LLaMAConfig = LLaMAConfig(max_seq_len=1024, max_batch_size=32, **params)
    tokenizer = LLaMATokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = LLaMA(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMAGenerator(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
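
A minimal driver sketch for the file above, assuming a 1D tensor-parallel context has already been set up via colossalai.launch_from_torch and that the checkpoint and tokenizer paths exist; the paths and parallel size here are placeholders, not part of the commit:

import colossalai

# assumed config: 1D tensor parallelism; size is a placeholder
colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(mode='1d', size=1))))

generator = load("checkpoints/llama", "checkpoints/tokenizer.model", local_rank=0, world_size=1)
print(generator.generate(["The capital of France is"], max_gen_len=32)[0])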
@@ -0,0 +1,25 @@
from typing import Optional

from ..base import Actor
from .llama import LLaMA, LLaMAConfig


class LLaMAActor(Actor):
    """
    LLaMA Actor model.

    Args:
        config (LLaMAConfig): Model config; expected to carry a valid vocab_size.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, config: Optional[LLaMAConfig] = None, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:

        model = LLaMA(config)

        super().__init__(model, lora_rank, lora_train_bias)
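
A hypothetical instantiation sketch, assuming the Actor base class from this repo's RLHF module wraps the backbone with optional LoRA (its API is not shown in this commit); 32000 is the usual LLaMA SentencePiece vocabulary size:

# hypothetical usage; vocab_size must be set before building the model
config = LLaMAConfig(vocab_size=32000)
actor = LLaMAActor(config=config, lora_rank=8, lora_train_bias='none')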
@@ -0,0 +1,23 @@
from typing import Optional

import torch.nn as nn

from ..base import Critic
from .llama import LLaMA, LLaMAConfig


class LLaMACritic(Critic):
    """
    LLaMA Critic model.

    Args:
        config (LLaMAConfig): Model config; expected to carry a valid vocab_size.
        checkpoint (bool): Enable gradient checkpointing (accepted but unused here).
    """

    def __init__(self, config: Optional[LLaMAConfig] = None, checkpoint: bool = False) -> None:

        model = LLaMA(config)
        # scalar value head on top of the backbone's hidden states
        value_head = nn.Linear(model.config.dim, 1)
        super().__init__(model, value_head)
@@ -0,0 +1,23 @@
from typing import Optional

import torch.nn as nn

from ..base import RewardModel
from .llama import LLaMA, LLaMAConfig


class LLaMARM(RewardModel):
    """
    LLaMA Reward model.

    Args:
        config (LLaMAConfig): Model config; expected to carry a valid vocab_size.
        checkpoint (bool): Enable gradient checkpointing (accepted but unused here).
    """

    def __init__(self, config: Optional[LLaMAConfig] = None, checkpoint: bool = False) -> None:

        model = LLaMA(config)
        # scalar reward head on top of the backbone's hidden states
        value_head = nn.Linear(model.config.dim, 1)
        super().__init__(model, value_head)
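
A minimal shape check for the value head, as a standalone sketch independent of the Critic/RewardModel base-class API (which is not part of this commit):

import torch
import torch.nn as nn

hidden = torch.randn(2, 16, 512)          # (batch, seq, dim) backbone hidden states
value_head = nn.Linear(512, 1)
values = value_head(hidden).squeeze(-1)   # (batch, seq): one scalar value per token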