In [89]:
import torch
from torch import Tensor
from torch.optim import Optimizer
from typing import List, Optional

In [90]:
x = torch.randn([1, 100, 3])
x

tensor([[[ 0.1477,  1.1761,  0.2385],
         [ 0.9800, -0.5217,  0.1854],
         [-0.7014, -0.3157,  0.0058],
         [-1.2235,  0.5164, -0.1806],
         [-1.1853,  0.4897, -1.5347],
         [-0.7726, -0.1140, -2.1939],
         [-0.5408, -0.0099,  0.6673],
         [ 2.4713,  0.6673,  1.4000],
         [-0.7779,  0.5023, -1.3360],
         [ 0.5957,  1.3773, -0.3514],
         [-0.5962, -0.1911,  1.0068],
         [-1.4420, -0.9462,  0.8203],
         [ 2.5332,  1.6473,  0.2436],
         [ 0.3981, -1.0504,  1.0002],
         [-0.0571,  1.0098, -0.6806],
         [-0.8278, -0.9173,  1.5652],
         [ 1.1461, -0.2537, -1.3157],
         [ 1.7280, -0.5511, -1.6518],
         [ 0.2686, -0.5983,  0.9149],
         [-0.1657, -0.5920,  1.4845],
         [ 0.3906, -2.1320, -0.6200],
         [ 1.7664, -1.2214,  0.3223],
         [ 0.5007, -1.1563,  0.9554],
         [-0.6403,  0.4780, -0.6033],
         [ 0.9015,  0.7944, -0.2937],
         [-0.4114, -0.6812, -1.9313],
         [ 0

In [91]:
target = torch.randn([1, 100, 64])
target

tensor([[[ 1.0725, -1.0301,  1.8985,  ..., -0.4468, -0.8776,  1.3075],
         [-0.8638, -0.1594, -2.1663,  ...,  0.6597,  0.3633, -0.6196],
         [ 0.5475, -0.0219, -0.9264,  ..., -1.0786,  1.7337, -0.9595],
         ...,
         [-0.2996,  1.1045,  1.3370,  ...,  1.2970, -0.1408, -0.1617],
         [ 0.1590, -1.5545, -1.7229,  ..., -0.8568,  0.2522, -0.8653],
         [ 0.0266, -2.0725,  0.3277,  ..., -0.3328,  0.7427, -0.5293]]])

In [92]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(3, 64)  # w*x+b, w - weight, b - bias

    def forward(self, x):  # __call__ -> forward()
        x = self.fc1(x)
        return x

In [93]:
model = Net()

In [94]:
model(x)

tensor([[[ 6.4221e-02,  5.8706e-01,  1.0781e-03,  ...,  4.2527e-01,
           6.9483e-01,  8.0912e-01],
         [ 6.1115e-01, -1.2252e-01, -1.3593e+00,  ...,  5.1047e-01,
          -5.1240e-01,  4.1108e-01],
         [ 4.0667e-01, -4.2657e-01, -3.0717e-01,  ...,  5.8798e-01,
           1.9267e-01,  2.1620e-01],
         ...,
         [ 8.2103e-01, -9.1886e-02, -1.4527e+00,  ...,  1.0664e+00,
           1.7230e-01,  7.2662e-02],
         [ 1.0773e+00, -6.0008e-01, -1.7030e+00,  ...,  1.3178e+00,
           7.6989e-02, -3.3255e-01],
         [ 8.0428e-01, -1.4907e-01, -1.1929e+00,  ...,  1.1946e+00,
           4.9957e-01, -3.6078e-02]]], grad_fn=<ViewBackward0>)

In [95]:
model(x).shape

torch.Size([1, 100, 64])

In [96]:
criterion = torch.nn.MSELoss()

In [101]:
class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=0.0003, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            has_sparse_grad = False

            for p in group['params']:
                print('Param:', p)
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)
                    if p.grad.is_sparse:
                        has_sparse_grad = True

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'])

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: bool = None,
        foreach: bool = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize)

def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       *,
                       weight_decay: float,
                       momentum: float,
                       lr: float,
                       dampening: float,
                       nesterov: bool,
                       maximize: bool,
                       has_sparse_grad: bool):

    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        alpha = lr if maximize else -lr
        param.add_(d_p, alpha=alpha)


def _multi_tensor_sgd(params: List[Tensor],
                      grads: List[Tensor],
                      momentum_buffer_list: List[Optional[Tensor]],
                      *,
                      weight_decay: float,
                      momentum: float,
                      lr: float,
                      dampening: float,
                      nesterov: bool,
                      maximize: bool,
                      has_sparse_grad: bool):

    if len(params) == 0:
        return

    if has_sparse_grad is None:
        has_sparse_grad = any([grad.is_sparse for grad in grads])

    if weight_decay != 0:
        grads = torch._foreach_add(grads, params, alpha=weight_decay)

    if momentum != 0:
        bufs = []

        all_states_with_momentum_buffer = True
        for i in range(len(momentum_buffer_list)):
            if momentum_buffer_list[i] is None:
                all_states_with_momentum_buffer = False
                break
            else:
                bufs.append(momentum_buffer_list[i])

        if all_states_with_momentum_buffer:
            torch._foreach_mul_(bufs, momentum)
            torch._foreach_add_(bufs, grads, alpha=1 - dampening)
        else:
            bufs = []
            for i in range(len(momentum_buffer_list)):
                if momentum_buffer_list[i] is None:
                    buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach()
                else:
                    buf = momentum_buffer_list[i]
                    buf.mul_(momentum).add_(grads[i], alpha=1 - dampening)

                bufs.append(buf)

        if nesterov:
            torch._foreach_add_(grads, bufs, alpha=momentum)
        else:
            grads = bufs

    alpha = lr if maximize else -lr
    if not has_sparse_grad:
        torch._foreach_add_(params, grads, alpha=alpha)
    else:
        # foreach APIs dont support sparse
        for i in range(len(params)):
            params[i].add_(grads[i], alpha=alpha)

In [102]:
optimizer = SGD(model.parameters(), lr=0.0003)

In [103]:
for epoch in range(1):
    output = model(x)
    loss = criterion(output, target)
    print('Epoch: ', epoch, 'Loss: ', loss.item())
    optimizer.zero_grad()  # w.grad, b.grad - reset to zero
    loss.backward()  # update w.grad, b.grad
    optimizer.step()

Epoch:  0 Loss:  1.4261630773544312
Param: Parameter containing:
tensor([[ 0.0550, -0.3038,  0.2754],
        [ 0.2229,  0.5208,  0.2039],
        [-0.4984,  0.5737, -0.5341],
        [-0.3640, -0.0862, -0.4289],
        [-0.3080, -0.3587,  0.1133],
        [ 0.2690, -0.2616,  0.2324],
        [ 0.5185, -0.4858,  0.4483],
        [ 0.2765,  0.3116, -0.3453],
        [ 0.4495,  0.0807,  0.5607],
        [ 0.5556,  0.5078,  0.4169],
        [ 0.0366,  0.3227, -0.1193],
        [ 0.4773,  0.0684,  0.4565],
        [-0.3486, -0.5000,  0.0359],
        [-0.1063,  0.5527,  0.0112],
        [ 0.0617,  0.5012,  0.3767],
        [-0.1005, -0.0750,  0.1014],
        [-0.0017, -0.2673,  0.5201],
        [ 0.4541,  0.0377, -0.1670],
        [ 0.0227, -0.1711, -0.0209],
        [-0.0901, -0.4105, -0.3948],
        [ 0.3868, -0.0460,  0.3663],
        [ 0.2310,  0.1717, -0.0698],
        [-0.5369,  0.3327, -0.1426],
        [ 0.1396,  0.0988,  0.3782],
        [ 0.3268, -0.3317,  0.3050],
        [ 