[booster] init module structure and definition (hpcaitech#3056)
FrankLeeeee authored and Fazziekey committed Mar 9, 2023
1 parent faa8526 commit 4bd3ebc
Showing 11 changed files with 600 additions and 0 deletions.
316 changes: 316 additions & 0 deletions applications/ChatGPT/chatgpt/models/llama/llama.py
@@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import json
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer import Embedding1D, Linear1D_Col, Linear1D_Row

from .llama_tokenizer import LLaMATokenizer


@dataclass
class LLaMAConfig:
dim: int = 512
n_layers: int = 8
n_heads: int = 8
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5

max_batch_size: int = 32
max_seq_len: int = 1024


class RMSNorm(nn.Module):

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight


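# Precompute the complex rotation factors e^{i * t * theta_j} used by rotary position
# embeddings (RoPE): one unit-magnitude complex number per (position, frequency) pair.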
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
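    # View consecutive feature pairs as complex numbers and rotate them by the
    # position-dependent phase stored in freqs_cis.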
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):

def __init__(self, args: LLaMAConfig):
super().__init__()

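        # Attention heads are sharded across the 1D tensor-parallel group,
        # so each rank only keeps n_heads // world_size heads locally.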
self.n_local_heads = args.n_heads // gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.head_dim = args.dim // args.n_heads

self.wq = Linear1D_Col(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = Linear1D_Col(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = Linear1D_Col(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = Linear1D_Row(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)

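        # Preallocated key/value caches so that incremental decoding only has to
        # process the newly generated tokens at each step.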
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

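        # Write the new keys/values into the cache at start_pos, then attend over
        # the full cached prefix (positions 0 .. start_pos + seqlen).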
self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv

keys = self.cache_k[:bsz, :start_pos + seqlen]
values = self.cache_v[:bsz, :start_pos + seqlen]

xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

return self.wo(output)


class FeedForward(nn.Module):

def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
):
super().__init__()
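        # SwiGLU uses three projections, so the nominal hidden size is scaled to 2/3
        # and then rounded up to the next multiple of `multiple_of`.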
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

self.w1 = Linear1D_Col(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = Linear1D_Row(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = Linear1D_Col(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):

def __init__(self, layer_id: int, args: LLaMAConfig):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out


class LLaMA(nn.Module):

def __init__(self, config: LLaMAConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.n_layers = config.n_layers

self.tok_embeddings = Embedding1D(config.vocab_size, config.dim, init_method=lambda x: x)

self.layers = torch.nn.ModuleList()
for layer_id in range(config.n_layers):
self.layers.append(TransformerBlock(layer_id, config))

self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = Linear1D_Col(config.dim, config.vocab_size, bias=False, init_method=lambda x: x)

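        # Rotary embedding table precomputed once for up to 2 * max_seq_len positions;
        # forward() slices out the rows it needs per step.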
self.freqs_cis = precompute_freqs_cis(self.config.dim // self.config.n_heads, self.config.max_seq_len * 2)

# @torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

mask = None
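        # For the multi-token (prompt) pass, build a causal mask: -inf above the
        # diagonal blocks attention to future positions.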
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h[:, -1, :]) # only compute last logits
return output.float()


class LLaMAGenerator:

def __init__(self, model: LLaMA, tokenizer: LLaMATokenizer):
self.model = model
self.tokenizer = tokenizer

def generate(
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
) -> List[str]:
bsz = len(prompts)
config = self.model.config
assert bsz <= config.max_batch_size, (bsz, config.max_batch_size)

prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

min_prompt_size = min([len(t) for t in prompt_tokens])
max_prompt_size = max([len(t) for t in prompt_tokens])

total_len = min(config.max_seq_len, max_gen_len + max_prompt_size)

tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
for k, t in enumerate(prompt_tokens):
tokens[k, :len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id
start_pos = min_prompt_size
prev_pos = 0
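        # Incremental decoding: only the tokens produced since the previous step are
        # fed to the model; earlier positions are served from the attention KV cache.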
for cur_pos in range(start_pos, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
prev_pos = cur_pos

decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[:len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[:t.index(self.tokenizer.eos_id)]
except ValueError:
pass
decoded.append(self.tokenizer.decode(t))
return decoded


def sample_top_p(probs, p):
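    # Nucleus (top-p) sampling: sort tokens by probability, keep the smallest prefix
    # whose cumulative mass covers p, renormalize, and sample from that prefix.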
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMAGenerator:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert (world_size == len(checkpoints)
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
print("Loading")
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())

model_args: LLaMAConfig = LLaMAConfig(max_seq_len=1024, max_batch_size=32, **params)
tokenizer = LLaMATokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = LLaMA(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)

generator = LLaMAGenerator(model, tokenizer)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return generator
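
For reference, a minimal usage sketch of the module above (not part of the commit); the checkpoint directory, tokenizer path, and rank values are placeholder assumptions, and the 1D tensor-parallel context is assumed to be initialized beforehand:

# Hypothetical usage sketch -- all paths and rank values below are placeholders.
generator = load(
    ckpt_dir="/path/to/llama/checkpoints",
    tokenizer_path="/path/to/tokenizer.model",
    local_rank=0,
    world_size=1,
)
print(generator.generate(["Hello, my name is"], max_gen_len=64, temperature=0.8, top_p=0.95)[0])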
25 changes: 25 additions & 0 deletions applications/ChatGPT/chatgpt/models/llama/llama_actor.py
@@ -0,0 +1,25 @@
from typing import Optional

import torch

from ..base import Actor
from .llama import LLaMA, LLaMAConfig


class LLaMAActor(Actor):
"""
    LLaMA Actor model.
    Args:
        config (LLaMAConfig): Model config.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""

def __init__(self, config: Optional[LLaMAConfig] = None, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:

model = LLaMA(config)

super().__init__(model, lora_rank, lora_train_bias)
23 changes: 23 additions & 0 deletions applications/ChatGPT/chatgpt/models/llama/llama_critic.py
@@ -0,0 +1,23 @@
from typing import Optional

import torch.nn as nn

from ..base import Critic
from .llama import LLaMA, LLaMAConfig


class LLaMACritic(Critic):
"""
    LLaMA Critic model.
    Args:
        config (LLaMAConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
"""

def __init__(self, config: Optional[LLaMAConfig] = None, checkpoint: bool = False) -> None:

model = LLaMA(config)
value_head = nn.Linear(model.config.dim, 1)
super().__init__(model, value_head)
23 changes: 23 additions & 0 deletions applications/ChatGPT/chatgpt/models/llama/llama_rm.py
@@ -0,0 +1,23 @@
from typing import Optional

import torch.nn as nn

from ..base import RewardModel
from .llama import LLaMA, LLaMAConfig


class LLaMARM(RewardModel):
"""
    LLaMA Reward model.
    Args:
        config (LLaMAConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
"""

def __init__(self, config: Optional[LLaMAConfig] = None, checkpoint: bool = False) -> None:

model = LLaMA(config)
value_head = nn.Linear(model.config.dim, 1)
super().__init__(model, value_head)