-
Notifications
You must be signed in to change notification settings - Fork 121
/
yogi.py
executable file
·36 lines (26 loc) · 1.07 KB
/
yogi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import numpy as np
class YoGi:
    """Yogi-style adaptive rescaling of a list of gradient tensors.

    The object keeps one first-moment tensor (``m_t``) and one second-moment
    tensor (``v_t``) per gradient position. Each call to :meth:`update`
    advances both moments and returns the rescaled step for every position:

        m <- beta * m + (1 - beta) * g
        v <- v - (1 - beta2) * g^2 * sign(v - g^2)
        step = eta / (sqrt(v) + tau) * m

    Attributes:
        eta: Numerator of the per-element learning rate.
        tau: Initial fill value of ``v_t`` and the additive term in the
            denominator of the per-element learning rate.
        beta: Decay factor of the first moment ``m_t``.
        beta2: Decay factor used in the second-moment update of ``v_t``.
        v_t: Per-position second-moment tensors (empty until first update).
        m_t: Per-position first-moment tensors (empty until first update).
    """

    def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99):
        """Store the hyper-parameters and start with empty moment state."""
        self.eta = eta
        self.tau = tau
        self.beta = beta
        self.beta2 = beta2
        # State is created lazily on the first update() call, once the
        # shapes/dtypes of the gradients are known.
        self.v_t = []
        self.m_t = []

    def update(self, gradients):
        """Advance the moment estimates and return the rescaled gradients.

        Args:
            gradients: Iterable of ``torch.Tensor``. Position ``i`` is matched
                with ``self.m_t[i]`` / ``self.v_t[i]`` across calls. Any
                iterable is accepted; one-shot iterables (generators,
                iterators) are materialised into a list first.

        Returns:
            A new list with one rescaled tensor per input gradient. If there
            are no gradients, the (materialised) input is returned as is.
        """
        # The input is traversed several times below (state initialisation
        # and the main loop). A generator would be exhausted by the first
        # traversal, leaving the state half-initialised and producing no
        # update at all, so materialise anything that is not already a
        # re-iterable sequence.
        if not isinstance(gradients, (list, tuple)):
            gradients = list(gradients)

        if not self.v_t:
            # First call: v starts at tau everywhere, m starts at zero.
            self.v_t = [torch.full_like(g, self.tau) for g in gradients]
            self.m_t = [torch.full_like(g, 0.0) for g in gradients]

        update_gradients = []
        for idx, gradient in enumerate(gradients):
            gradient_square = gradient**2
            self.m_t[idx] = self.beta * self.m_t[idx] + (1.0 - self.beta) * gradient
            # Additive second-moment update: v moves towards g^2 by a step
            # of fixed relative size, with the direction given by the sign.
            self.v_t[idx] = self.v_t[idx] - (
                1.0 - self.beta2
            ) * gradient_square * torch.sign(self.v_t[idx] - gradient_square)
            yogi_learning_rate = self.eta / (torch.sqrt(self.v_t[idx]) + self.tau)
            update_gradients.append(yogi_learning_rate * self.m_t[idx])

        if not update_gradients:
            # Nothing to rescale: hand the (empty) input back unchanged.
            update_gradients = gradients
        return update_gradients