Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate out the biases #1156

Closed
wants to merge 11 commits into from
4 changes: 3 additions & 1 deletion config_hub/pretrain/tinystories.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ model_config:
head_size: 48
rotary_percentage: 1.0
parallel_residual: false
bias: false
attn_qkv_bias: false
attn_proj_bias: false
mlp_bias: false
norm_class_name: RMSNorm
mlp_class_name: LLaMAMLP
intermediate_size: 768
Expand Down
14 changes: 7 additions & 7 deletions litgpt/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.attn_qkv_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.attn_proj_bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None

Expand Down Expand Up @@ -157,8 +157,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
class GptNeoxMLP(litgpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.mlp_bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.mlp_bias)

self.config = config

Expand All @@ -177,9 +177,9 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
class LLaMAMLP(litgpt.model.LLaMAMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.mlp_bias)
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.mlp_bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.mlp_bias)

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
Expand Down
Loading
Loading