5 changes: 3 additions & 2 deletions llmc/models/falcon.py
@@ -13,15 +13,16 @@ def find_blocks(self):

     def find_embed_layers(self):
         self.word_embeddings = self.model.transformer.word_embeddings
+        self.rotary_emb = self.model.transformer.rotary_emb

     def find_block_name(self):
         self.block_name_prefix = 'model.transformer.h'

     def get_embed_layers(self):
-        return [self.word_embeddings]
+        return [self.word_embeddings, self.rotary_emb]

     def get_layers_except_blocks(self):
-        return [self.word_embeddings, self.model.transformer.ln_f]
+        return [self.word_embeddings, self.rotary_emb, self.model.transformer.ln_f]

     def has_bias(self):
         return False
5 changes: 3 additions & 2 deletions llmc/models/llama.py
@@ -13,13 +13,14 @@ def find_blocks(self):

     def find_embed_layers(self):
         self.embed_tokens = self.model.model.embed_tokens
+        self.rotary_emb = self.model.model.rotary_emb

     def find_block_name(self):
         self.block_name_prefix = 'model.layers'
         self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

     def get_embed_layers(self):
-        return [self.embed_tokens]
+        return [self.embed_tokens, self.rotary_emb]

     def get_head_layers(self):
         return [self.model.lm_head]
@@ -28,7 +29,7 @@ def get_pre_head_layernorm_layers(self):
         return [self.model.model.norm]

     def get_layers_except_blocks(self):
-        return [self.embed_tokens, self.model.model.norm, self.model.lm_head]
+        return [self.embed_tokens, self.rotary_emb, self.model.model.norm, self.model.lm_head]  # noqa

     def skip_layer_name(self):
         return ['lm_head']
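For context: in the transformers release this PR targets (4.45.x), the rotary position embedding is no longer created inside every attention layer; it is a single shared rotary_emb module hanging off the base model, which is why these wrappers now register it alongside the embedding layers. A minimal sketch of the layout this relies on (the checkpoint id is illustrative, not part of this PR):

# Sketch only: inspect the module layout the wrappers above depend on
# (transformers 4.45.x; the checkpoint id below is illustrative).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',   # any Llama-style checkpoint
    torch_dtype=torch.float16,
)

# One shared rotary embedding lives on the base model, not on each layer.
print(type(model.model.rotary_emb).__name__)  # e.g. LlamaRotaryEmbedding

# Modules outside the decoder blocks that a block-wise quantizer has to
# handle separately -- the same list get_layers_except_blocks() returns.
outside_blocks = [
    model.model.embed_tokens,
    model.model.rotary_emb,
    model.model.norm,
    model.lm_head,
]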
5 changes: 3 additions & 2 deletions llmc/models/phi.py
@@ -13,13 +13,14 @@ def find_blocks(self):

     def find_embed_layers(self):
         self.embed_tokens = self.model.model.embed_tokens
+        self.rotary_emb = self.model.model.rotary_emb

     def find_block_name(self):
         self.block_name_prefix = 'model.layers'
         self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

     def get_embed_layers(self):
-        return [self.embed_tokens]
+        return [self.embed_tokens, self.rotary_emb]

     def get_head_layers(self):
         return [self.model.lm_head]
@@ -28,7 +29,7 @@ def get_pre_head_layernorm_layers(self):
         return [self.model.model.final_layernorm]

     def get_layers_except_blocks(self):
-        return [self.embed_tokens, self.model.model.final_layernorm, self.model.lm_head]
+        return [self.embed_tokens, self.rotary_emb, self.model.model.final_layernorm, self.model.lm_head]  # noqa

     def skip_layer_name(self):
         return ['lm_head']
5 changes: 3 additions & 2 deletions llmc/models/qwen2.py
@@ -13,13 +13,14 @@ def find_blocks(self):

     def find_embed_layers(self):
         self.embed_tokens = self.model.model.embed_tokens
+        self.rotary_emb = self.model.model.rotary_emb

     def find_block_name(self):
         self.block_name_prefix = 'model.layers'
         self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

     def get_embed_layers(self):
-        return [self.embed_tokens]
+        return [self.embed_tokens, self.rotary_emb]

     def get_head_layers(self):
         return [self.model.lm_head]
@@ -28,7 +29,7 @@ def get_pre_head_layernorm_layers(self):
         return [self.model.model.norm]

     def get_layers_except_blocks(self):
-        return [self.embed_tokens, self.model.model.norm, self.model.lm_head]
+        return [self.embed_tokens, self.rotary_emb, self.model.model.norm, self.model.lm_head]  # noqa

     def skip_layer_name(self):
         return ['lm_head']
5 changes: 3 additions & 2 deletions llmc/models/stablelm.py
@@ -10,6 +10,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):

     def find_blocks(self):
         self.blocks = self.model.model.layers
+        self.rotary_emb = self.model.model.rotary_emb

     def find_embed_layers(self):
         self.embed_tokens = self.model.model.embed_tokens
@@ -19,7 +20,7 @@ def find_block_name(self):
         self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

     def get_embed_layers(self):
-        return [self.embed_tokens]
+        return [self.embed_tokens, self.rotary_emb]

     def get_head_layers(self):
         return [self.model.lm_head]
@@ -28,7 +29,7 @@ def get_pre_head_layernorm_layers(self):
         return [self.model.model.norm]

     def get_layers_except_blocks(self):
-        return [self.embed_tokens, self.model.model.norm, self.model.lm_head]
+        return [self.embed_tokens, self.rotary_emb, self.model.model.norm, self.model.lm_head]  # noqa

     def skip_layer_name(self):
         return ['lm_head']
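All five wrappers follow the same pattern, so one note covers them: the shared rotary module produces the position embeddings consumed by every decoder block, so a block-wise pipeline that moves non-block modules to the calibration device by hand must move rotary_emb too, or the per-block forward pass can fail with a device mismatch. A rough, hypothetical consumer of this wrapper API (not llmc's actual calibration code) might look like:

# Hypothetical sketch of how the lists above could be consumed; the function
# name and loop body are illustrative, not taken from llmc.
def calibrate_blockwise(wrapper, device='cuda'):
    # Embeddings, rotary_emb, final norm and head live outside the blocks,
    # so they are moved once up front.
    for module in wrapper.get_layers_except_blocks():
        module.to(device)
    # Decoder blocks are then processed one at a time to bound memory use.
    for block in wrapper.blocks:
        block.to(device)
        # ... collect calibration activations / quantize this block ...
        block.to('cpu')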
2 changes: 1 addition & 1 deletion requirements/runtime.txt
@@ -1,7 +1,7 @@
 torch>=2.1.0
 pillow
 loguru
-transformers>=4.41.2
+transformers==4.45.2
 huggingface-hub
 sentencepiece
 protobuf
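The pin from transformers>=4.41.2 to transformers==4.45.2 lines up with the wrapper changes above: the model-level rotary_emb attribute those wrappers read is only present in recent releases, and an exact pin keeps the attribute layout from shifting under the tool. A quick sanity check after installing (the tiny test checkpoint is illustrative):

# Sanity check, assuming the environment was installed from requirements/runtime.txt.
import transformers
from transformers import AutoModelForCausalLM

print(transformers.__version__)  # expected: 4.45.2

# A tiny random checkpoint is enough to confirm the module layout
# (repo id is illustrative).
m = AutoModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-LlamaForCausalLM')
print(hasattr(m.model, 'rotary_emb'))  # True on 4.45.x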