73 changes: 71 additions & 2 deletions lmdeploy/model.py
@@ -219,14 +219,83 @@ def __init__(self, session_len=8192, **kwargs):
self.session_len = session_len


@MODELS.register_module(name='baichuan-7b')
class Baichuan7B(BaseModel):
@MODELS.register_module(name='baichuan')
class Baichuan(BaseModel):

def __init__(self, repetition_penalty=1.1, **kwargs):
super().__init__(**kwargs)
self.repetition_penalty = repetition_penalty


@MODELS.register_module(name='baichuan-chat')
class BaichuanChat(BaseModel):

def __init__(self,
repetition_penalty=1.1,
user_token='<reserved_102>',
assistant_token='<reserved_103>',
temperature=0.3,
top_k=5,
top_p=0.85,
**kwargs):
super().__init__(**kwargs)
self.repetition_penalty = repetition_penalty
self.user_token = user_token
self.assistant_token = assistant_token
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p

def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
return f'{self.user_token}{prompt}{self.assistant_token}'
else:
return f'{self.user_token}{prompt}{self.assistant_token}'

def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.

Args:
messages (str | List): user's input prompt
sequence_start (bool): flag to start the sequence
Returns:
str: the concatenated prompt
"""
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
system = '' if not system else system
ret = f'{system}'
for user, assistant in zip(users, assistants):
if assistant:
ret += f'{self.user_token}{user}{self.assistant_token}' \
f'{assistant}'
else:
ret += f'{self.user_token}{user}{self.assistant_token}'
return ret


@MODELS.register_module(name='baichuan2-chat')
class Baichuan2Chat(BaichuanChat):

def __init__(self,
repetition_penalty=1.1,
user_token='<reserved_106>',
assistant_token='<reserved_107>',
temperature=0.3,
top_k=5,
top_p=0.85,
**kwargs):
super().__init__(**kwargs)
self.repetition_penalty = repetition_penalty
self.user_token = user_token
self.assistant_token = assistant_token
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p


@MODELS.register_module(name='puyu')
class Puyu(BaseModel):
"""Chat template of puyu model.This is only for internal usage in Shanghai
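For reference, a minimal sketch of the prompts the new Baichuan templates above produce, assuming the registry access pattern and the _translate_messages helper from lmdeploy.model (the message list is illustrative):

from lmdeploy.model import MODELS

messages = [
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello!'},
    {'role': 'user', 'content': 'Who are you?'},
]

chat_v1 = MODELS.get('baichuan-chat')()
chat_v2 = MODELS.get('baichuan2-chat')()

# Baichuan wraps every user turn in its reserved tokens:
# <reserved_102>Hi<reserved_103>Hello!<reserved_102>Who are you?<reserved_103>
print(chat_v1.messages2prompt(messages))

# Baichuan2 does the same with <reserved_106>/<reserved_107>:
# <reserved_106>Hi<reserved_107>Hello!<reserved_106>Who are you?<reserved_107>
print(chat_v2.messages2prompt(messages))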
2 changes: 1 addition & 1 deletion lmdeploy/pytorch_poc/chat.py
@@ -40,7 +40,7 @@ def main(
interface.

Args:
model_path (str): the path of the deployed model
model_path (str): the huggingface model path
session_id (int): the identical id of a session
repetition_penalty (float): parameter to penalize repetition
tp (int): GPU number used in tensor parallelism
8 changes: 6 additions & 2 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -346,7 +346,8 @@ def __init__(self,
tp: int = 1) -> None:

self.tp = tp
hf_config = AutoConfig.from_pretrained(model_path)
hf_config = AutoConfig.from_pretrained(model_path,
trust_remote_code=True)
torch_dtype = getattr(hf_config, 'torch_dtype', 'float16')
torch_dtype = eval(f'torch.{torch_dtype}')
self.torch_dtype = torch_dtype
@@ -374,7 +375,9 @@ def __init__(self,
if tp == 1:
with LoadNoInit():
hf_model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype='auto')
model_path,
torch_dtype=torch_dtype,
trust_remote_code=True)
hf_model.eval()

self.patched_model = patch(hf_model,
@@ -605,6 +608,7 @@ def step(self, return_logits=False):
TemperatureLogitsWarper(param.temperature),
])
logit = logits_processor(input_ids, logit)
logit = logit.reshape([-1, logit.shape[-1]])
next_token_ids.append(logit[-1].argmax())

# update scheduler
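The engine changes above exist because Baichuan checkpoints ship their modeling code inside the checkpoint repository, so both the config and the weights must be loaded with trust_remote_code=True. A minimal standalone sketch of that loading path (the model path is illustrative):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = 'baichuan-inc/Baichuan-13B-Chat'  # illustrative

# Both the config and the weights need trust_remote_code because the
# modeling code lives inside the checkpoint repository.
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Resolve the dtype recorded in the config, falling back to float16.
torch_dtype = getattr(hf_config, 'torch_dtype', None) or 'float16'
if isinstance(torch_dtype, str):
    torch_dtype = getattr(torch, torch_dtype)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, trust_remote_code=True)
hf_model.eval()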
200 changes: 200 additions & 0 deletions lmdeploy/pytorch_poc/patch/baichuan.py
@@ -0,0 +1,200 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from lmdeploy.pytorch_poc.kernels import paged_attention_fwd

from .llama import apply_rotary_pos_emb


class BaichuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
if self.context.use_origin:
return self.origin_mod(hidden_states, attention_mask, position_ids,
past_key_value, output_attentions,
use_cache)
else:
return self._contiguous_batching_forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache)

def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
assert not output_attentions
origin_self = self.origin_mod

context = self.context.context
history_lengths = context.history_lengths
max_seq_len = position_ids.size(-1)

proj = origin_self.W_pack(hidden_states)
proj = proj.unflatten(
-1, (3, origin_self.hidden_size)).unsqueeze(0).transpose(
0, -2).squeeze(-2)
query_states = proj[0].view(-1, origin_self.num_heads,
origin_self.head_dim)
key_states = proj[1].view(-1, origin_self.num_heads,
origin_self.head_dim)
value_states = proj[2].view(-1, origin_self.num_heads,
origin_self.head_dim)

kv_seq_len = max_seq_len + max(history_lengths)
# TODO: setup past_key_value with paged attention
if hasattr(origin_self,
'rotary_emb'): # baichuan-13B has no rotary_emb
cos, sin = origin_self.rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos.to(value_states.dtype) # baichuan2 hard-coded it float32
sin = sin.to(value_states.dtype)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)

kv_seq_length = position_ids[..., -1] + 1
q_seq_length = kv_seq_length - kv_seq_length.new_tensor(
history_lengths)
q_start_loc = q_seq_length.cumsum(0)
q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
context.fill_cache(
key_states,
value_states,
q_start_loc,
q_seq_length,
past_key_value[0],
past_key_value[1],
)
attn_output = torch.empty_like(query_states)

block_offsets = context.block_offsets
block_size = past_key_value[0].size(1)
paged_attention_fwd(query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
b_start_loc=q_start_loc,
b_seq_len=q_seq_length,
b_kv_seq_len=kv_seq_length,
max_input_len=max_seq_len,
BLOCK=block_size)
attn_output = attn_output.reshape(-1, origin_self.hidden_size)

attn_output = origin_self.o_proj(attn_output)
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


class BaichuanLayer(torch.nn.Module):
wangruohui (Collaborator) commented on Sep 8, 2023:

It seems this modification is only for allowing position ids to be passed into attention to compute q_locs and q_seq_lens for paged attention. Since these two values are crucial for paged attention, I think we can pass them consistently through the context to avoid copying and pasting a lot of code just for changing the argument list.
I have made this work in the chatglm part (change in patch, change in input). If all (cc @grimoire) think this is a good idea, I can make a standalone PR for it.

Collaborator replied:

Agree, the context is designed to share data between all rewrite modules.
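
A rough sketch of that proposal, with made-up field names rather than the actual lmdeploy context API: compute the paged-attention bookkeeping once per step and publish it on the shared context, so every rewritten attention module reads it instead of re-deriving it from position_ids.

import torch

def publish_seq_info(context, position_ids, history_lengths):
    """Hypothetical helper: fill the shared context once per engine step."""
    kv_seq_length = position_ids[..., -1] + 1
    q_seq_length = kv_seq_length - kv_seq_length.new_tensor(history_lengths)
    q_start_loc = q_seq_length.cumsum(0)
    q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
    context.q_start_loc = q_start_loc
    context.q_seq_length = q_seq_length
    context.kv_seq_length = kv_seq_length

# Inside a rewritten attention module, the values would then be read back,
# e.g. q_start_loc = self.context.context.q_start_loc, instead of being
# recomputed in every patched class.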


def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:

residual = hidden_states

hidden_states = self.origin_mod.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.origin_mod.self_attn( # noqa
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.origin_mod.post_attention_layernorm(hidden_states)
hidden_states = self.origin_mod.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states, )

if use_cache:
outputs += (present_key_value, )

return outputs


class BaichuanForCausalLM(nn.Module):

def forward(self,
Collaborator commented:

Is this rewrite for https://github.com/baichuan-inc/Baichuan-7B/blob/6f3ef4633a90c2d8a3e0763d0dec1b8dc11588f5/models/modeling_baichuan.py#L581? I did not see much difference between this one and the original one.

Collaborator (Author) replied:

It's for the 13B model.

input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
**kwargs) -> Union[Tuple, CausalLMOutputWithPast]:

return_dict = return_dict if return_dict is not None else self.origin_mod.config.use_return_dict # noqa

# decoder outputs consists of
# (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.origin_mod.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=kwargs.get('position_ids', None),
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]
logits = self.origin_mod.lm_head(hidden_states)

loss = None

if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
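
The W_pack split in BaichuanAttention above does q/k/v projection in a single matmul and then untangles it with unflatten/transpose; a small shape walk-through with toy sizes (4 heads of dim 16 are assumed purely for illustration):

import torch

hidden = torch.randn(1, 5, 64)                 # (batch=1, seq=5, hidden=64)
w_pack = torch.nn.Linear(64, 3 * 64, bias=False)

proj = w_pack(hidden)                          # (1, 5, 192): packed q/k/v
proj = proj.unflatten(-1, (3, 64))             # (1, 5, 3, 64)
proj = proj.unsqueeze(0)                       # (1, 1, 5, 3, 64)
proj = proj.transpose(0, -2)                   # (3, 1, 5, 1, 64): q/k/v leads
proj = proj.squeeze(-2)                        # (3, 1, 5, 64)

query = proj[0].view(-1, 4, 16)                # (5, num_heads=4, head_dim=16)
key = proj[1].view(-1, 4, 16)
value = proj[2].view(-1, 4, 16)
print(query.shape, key.shape, value.shape)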
23 changes: 20 additions & 3 deletions lmdeploy/pytorch_poc/patch/patch.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import inspect
import re
from copy import copy
from typing import Dict, Sequence

@@ -22,10 +23,26 @@
'transformers.models.llama.modeling_llama.LlamaRMSNorm':
'lmdeploy.pytorch_poc.patch.llama.LlamaRMSNorm',
'transformers.models.llama.modeling_llama.LlamaDecoderLayer':
'lmdeploy.pytorch_poc.patch.llama.LlamaDecoderLayer'
'lmdeploy.pytorch_poc.patch.llama.LlamaDecoderLayer',
'transformers_modules\.(.*)\.modeling_baichuan.(.*)Model': # noqa
'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
'transformers_modules\.(.*)\.modeling_baichuan.(.*)Attention': # noqa
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention',
'transformers_modules\.(.*)\.modeling_baichuan.BaichuanForCausalLM': # noqa
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanForCausalLM',
'transformers_modules\.(.*)\.modeling_baichuan.BaichuanLayer': # noqa
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanLayer',
}


def _get_rewrite_qualname(origin_qualname: str):
global MODULE_MAP
for key, value in MODULE_MAP.items():
if re.search(key, origin_qualname):
return value
return None


def _class_from_qualname(qualname):
last_dot = qualname.rfind('.')
modname = qualname[:last_dot]
@@ -52,11 +69,11 @@ def _patch(model: torch.nn.Module, context: Addict):
module_name = inspect.getmodule(model).__name__
class_name = model.__class__.__name__
origin_qualname = f'{module_name}.{class_name}'
rewrite_qualname = MODULE_MAP.get(origin_qualname, None)
rewrite_qualname = _get_rewrite_qualname(origin_qualname)

if rewrite_qualname is None:
origin_qualname = class_name
rewrite_qualname = MODULE_MAP.get(origin_qualname, None)
rewrite_qualname = _get_rewrite_qualname(origin_qualname)

if rewrite_qualname is not None:
logger.debug(
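The switch from a plain dict lookup to re.search in _patch exists because classes loaded with trust_remote_code are imported under a per-checkpoint transformers_modules.<repo>.modeling_baichuan qualname. A self-contained sketch of the matching (the qualname below is illustrative):

import re

MODULE_MAP = {
    r'transformers_modules\.(.*)\.modeling_baichuan.(.*)Attention':
    'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention',
}

def _get_rewrite_qualname(origin_qualname):
    for key, value in MODULE_MAP.items():
        if re.search(key, origin_qualname):
            return value
    return None

# Remote-code checkpoints are imported under a repo-specific module path:
origin = 'transformers_modules.Baichuan-13B-Chat.modeling_baichuan.BaichuanAttention'
print(_get_rewrite_qualname(origin))
# -> lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention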