In [7]:
import torch
import torch.nn as nn 
from torch.optim.optimizer import Optimizer, required

class КуикПроп(Optimizer):
    def __init__(self, params, lr=1e-3, weight_decay=0, eps=1e-10):
      
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
            
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
            
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps)
       
        super(КуикПроп, self).__init__(params, defaults)

#     def __setstate__(self, state):
#         super(КуикПроп, self).__setstate__(state)
#         for group in self.param_groups:
#             group.setdefault('nesterov', False)
            
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['gradprev'] = torch.empty_like(p.data).uniform_(0, 1)
                    state['deltap']   = torch.ones_like(p.data)
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=weight_decay)
                
                state['step'] += 1
                deltaw1 = state['deltap'] * (grad / (state['gradprev'] - grad + group['eps']) )
                state['gradprev'] = grad.clone()
                p.data += group['lr'] * deltaw1
                state['deltap'] = deltaw1.clone()
          
        return loss

In [8]:
device = torch.device('cpu')
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

model = nn.Sequential(
        nn.Linear(D_in, H),
        nn.ReLU(),
        nn.Linear(H, D_out)
        ).to(device)

optimizer = КуикПроп(model.parameters(),lr=1e-4)
criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [9]:
# %%timeit
for i in range(1000):
    output = model(x)
    loss   = criterion(output,y) 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())
    

1.0835081338882446
1.0839401483535767
123.95174407958984
6.305595397949219
130.79318237304688
379.08709716796875
166.62191772460938
43568.4140625
334.9774169921875
723.9132690429688
76227.578125
22345.25
8838.640625
12660.6337890625
15517.7783203125
13134.927734375
15212.740234375
417370.78125
122580.140625
118590808.0
369884.4375
742924.0
162289.40625
4880400.5
3987435.25
130124560.0
7671061.5
3553849.5
4319557.5
2058356.25
7741330.5
75269904.0
1556340.0
1182133.0
210064.671875
445252.1875
237415.671875
161400.21875
303447.4375
603438.375
14265786368.0
366042.125
263845.5625
271932.0
83890.140625
299404.5
755561.125
246152.921875
142928.875
234435.34375
57349.0859375
122231.25
3413188.75
154424.578125
680588.1875
2749511.25
222728.671875
1636042.5
176158.65625
230846.90625
3170016.25
45935984.0
1023646.625
1198953.625
3089082.25
2391667.0
301282.40625
675741440.0
482498.9375
1386890.25
2255356.25
1865859.75
39080736.0
453010944.0
2258581504.0
41407528.0
113061808.0
18181894.0
4682790.

335.09356689453125
165.9208221435547
209226.71875
131.13946533203125
120.2416000366211
552.4688110351562
123.74565124511719
321.7882385253906
538.595458984375
116.58515930175781
142.04330444335938
98.99290466308594
177.76498413085938
185.86227416992188
169.98776245117188
233.5987548828125
178.05535888671875
744.8494873046875
43010.53515625
2807651.25
716.0108642578125
369.886962890625
107.12835693359375
798255.1875
103.81825256347656
104.34770202636719
230.057373046875
90.8101806640625
92.26609802246094
83.9195785522461
1233.37890625
120.27449035644531
141.70687866210938
330.37738037109375
82.81851959228516
756.8068237304688
199.6038360595703
317.87091064453125
132.0022735595703
89.5228271484375
79.75141906738281
99.5165023803711
78.353759765625
271.3082580566406
123.6045150756836
94.8426284790039
90.21927642822266
77.29061126708984
1955.868408203125
2388.0615234375
92.02392578125
66.22833251953125
111.07962799072266
159.90023803710938
1721.836181640625
79.71249389648438
59.75402832031