Skip to content

Commit

Permalink
Minimum Trust Lamb (facebookresearch#1186)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#1186

Implementation of minimum trust Lamb as described in section 6.3 of https://arxiv.org/abs/1911.11423

Differential Revision: D18893828

fbshipit-source-id: 06a8db92eda746d9de3e6bd5f0e3ed34c8be3966
  • Loading branch information
Akshat Shrivastava authored and facebook-github-bot committed Dec 10, 2019
1 parent d46a90f commit f46135e
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion pytext/optimizer/lamb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional

import torch
from pytext.optimizer.optimizers import Optimizer
Expand All @@ -19,6 +20,7 @@ class Config(Optimizer.Config):
lr: float = 0.001
weight_decay: float = 0.00001
eps: float = 1e-8
min_trust: Optional[float] = None

@classmethod
def from_config(cls, config: Config, model: torch.nn.Module):
Expand All @@ -27,9 +29,18 @@ def from_config(cls, config: Config, model: torch.nn.Module):
lr=config.lr,
weight_decay=config.weight_decay,
eps=config.eps,
min_trust=config.min_trust,
)

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0,
min_trust=None,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
Expand All @@ -44,6 +55,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0
{"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
)

self.min_trust = min_trust

def step(self, closure=None):
"""Performs a single optimization step.
Expand Down Expand Up @@ -104,6 +117,8 @@ def step(self, closure=None):
trust_ratio = 1
else:
trust_ratio = weight_norm / adam_norm
if self.min_trust:
trust_ratio = max(self.min_trust, trust_ratio)
state["weight_norm"] = weight_norm
state["adam_norm"] = adam_norm
state["trust_ratio"] = trust_ratio
Expand Down

0 comments on commit f46135e

Please sign in to comment.