Make Dropout a no-op when p=0.0 (#259)
Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
epwalsh and dirkgr committed Sep 12, 2023
1 parent a33dbb0 commit a49f4ec
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions olmo/model.py
@@ -45,6 +45,14 @@
 log = logging.getLogger(__name__)
 
 
+class Dropout(nn.Dropout):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.p == 0.0:
+            return input
+        else:
+            return F.dropout(input, self.p, self.training, self.inplace)
+
+
 class LayerNormBase(nn.Module):
     def __init__(self, config: ModelConfig):
         super().__init__()
@@ -283,7 +291,7 @@ def __init__(self, layer_id: int, config: ModelConfig):
         assert config.d_model % config.n_heads == 0
 
         # Dropout.
-        self.dropout = nn.Dropout(config.residual_dropout)
+        self.dropout = Dropout(config.residual_dropout)
 
         # Layer norms.
         self.k_norm: Optional[LayerNormBase] = None
@@ -639,7 +647,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
                 wte=nn.Embedding(
                     config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
                 ),
-                emb_drop=nn.Dropout(config.embedding_dropout),
+                emb_drop=Dropout(config.embedding_dropout),
                 blocks=nn.ModuleList([OlmoBlock.build(i, config) for i in range(config.n_layers)]),
                 ln_f=LayerNorm.build(config),
             )
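For context (not part of the commit): a minimal sketch of how the new `Dropout` subclass behaves. When `p == 0.0`, `forward` returns the input tensor untouched instead of routing it through `F.dropout`, which is a mathematical identity in that case anyway. The standalone script below is illustrative only; the tensor shapes and the `__main__` check are assumptions, not code from the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Dropout(nn.Dropout):
    """Dropout that short-circuits to a no-op when p == 0.0, as in the diff above."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.p == 0.0:
            # Hand the input back as-is; dropout with p=0 would not change it anyway.
            return input
        else:
            return F.dropout(input, self.p, self.training, self.inplace)


if __name__ == "__main__":
    x = torch.randn(2, 4, requires_grad=True)

    drop = Dropout(p=0.0)
    drop.train()
    # With p == 0.0 the module returns the very same tensor object.
    assert drop(x) is x

    drop = Dropout(p=0.5)
    drop.train()
    # With p > 0.0 it behaves like a standard nn.Dropout.
    assert drop(x).shape == x.shape
```

In the diff, this subclass simply replaces the two `nn.Dropout` call sites (`residual_dropout` and `embedding_dropout`), so configurations that set either rate to 0.0 skip the dropout call entirely.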
