In [None]:
# SparseSGDM optimizer: SGD with momentum + support for per-parameter gradient masks.
# Usage:
#   optimizer = SparseSGDM(model.parameters(), lr=0.1, momentum=0.9, gradient_masks=masks)
# where `masks` is a list of tensors, one per parameter (same ordering as model.parameters()).
# You can also set masks later with optimizer.set_param_masks(list_of_masks).

import torch
from torch.optim.optimizer import Optimizer, required

class SparseSGDM(Optimizer):
    """SGD with momentum (SGDM) that supports per-parameter gradient masks.

    Each parameter may carry a mask that broadcasts to the parameter's shape.
    The gradient, the weight-decay term and the momentum buffer are all
    multiplied by the mask, so entries where the mask is 0 receive no update
    at all (no grad, no weight decay, momentum zeroed). Non-binary masks are
    accepted, but the momentum buffer is re-multiplied by the mask on every
    step, so fractional values compound over time.

    Args:
        params: iterable of parameters or param-group dicts, as for any
            ``torch.optim.Optimizer``.
        lr: learning rate (required).
        momentum: momentum factor; 0 disables the momentum buffer.
        dampening: factor ``(1 - dampening)`` applied to the gradient when
            accumulating momentum (not applied when the buffer is created).
        weight_decay: L2 penalty coefficient added to the gradient.
        nesterov: use Nesterov momentum; requires ``momentum > 0`` and
            ``dampening == 0``.
        gradient_masks: optional sequence of masks, one per parameter, in the
            order parameters appear across ``param_groups``. An entry may be
            ``None`` (no masking) or anything ``torch.as_tensor`` accepts.

    Raises:
        ValueError: on invalid hyperparameters, a wrong number of masks, or a
            mask that does not broadcast to its parameter's shape.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, gradient_masks=None):
        # Validate before registering param groups, mirroring torch.optim.SGD.
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        super().__init__(params, defaults)

        if gradient_masks is not None:
            self._store_masks(gradient_masks, "gradient_masks")

    def _all_params(self):
        """Return every parameter across all param groups, in group order."""
        return [p for group in self.param_groups for p in group['params']]

    def _store_masks(self, masks, arg_name):
        """Validate ``masks`` against the optimizer's parameters and store them.

        Each non-None mask is converted with ``torch.as_tensor`` and placed on
        its parameter's device here, so ``step`` does not repeat a
        host-to-device copy every iteration. A tensor already on the right
        device is stored as-is (no copy). ``arg_name`` only customizes the
        length-mismatch error message.
        """
        masks = list(masks)  # accept any iterable, e.g. a generator
        params = self._all_params()
        if len(masks) != len(params):
            raise ValueError(f"{arg_name} length must match number of parameters")

        prepared = []
        for idx, (p, m) in enumerate(zip(params, masks)):
            if m is not None:
                m = torch.as_tensor(m, device=p.device)
                try:
                    shape = torch.broadcast_shapes(m.shape, p.shape)
                except RuntimeError:
                    shape = None
                # A mask that would enlarge the update can never be applied
                # in place to the parameter, so reject it up front.
                if shape != p.shape:
                    raise ValueError(
                        f"mask {idx} with shape {tuple(m.shape)} does not "
                        f"broadcast to parameter shape {tuple(p.shape)}")
            prepared.append(m)

        # Mutate state only after every mask has validated (all-or-nothing).
        for p, m in zip(params, prepared):
            self.state[p]['mask'] = m

    def set_param_masks(self, masks):
        """Replace masks after init. `masks` must be list aligned with optimizer params.

        Entries may be ``None`` to leave a parameter unmasked.

        Raises:
            ValueError: if the number of masks differs from the number of
                parameters, or a mask does not broadcast to its parameter.
        """
        self._store_masks(masks, "masks")

    def _get_mask(self, p):
        """Return ``p``'s mask on ``p``'s device and dtype, or None if unmasked."""
        mask = self.state.get(p, {}).get('mask')  # .get avoids creating state
        if mask is None:
            return None
        if mask.device != p.device:
            # The parameter moved after the mask was set; cache the move so
            # the copy happens once instead of on every step.
            mask = mask.to(p.device)
            self.state[p]['mask'] = mask
        return mask.type_as(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        This is an adapted implementation of torch.optim.SGD.step with masking.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs forward/backward, which needs
            # autograd even though the update itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                mask = self._get_mask(p)  # may be None
                d_p = p.grad if mask is None else p.grad.mul(mask)

                # Weight decay is masked too, so masked weights stay frozen.
                if weight_decay != 0:
                    wd = p.mul(weight_decay)
                    if mask is not None:
                        wd.mul_(mask)
                    d_p = d_p.add(wd)

                if momentum != 0:
                    param_state = self.state[p]
                    buf = param_state.get('momentum_buffer')
                    if buf is None:
                        # First step: seed the buffer with the (masked) d_p,
                        # without dampening, as torch.optim.SGD does.
                        buf = torch.clone(d_p).detach()
                        param_state['momentum_buffer'] = buf
                    else:
                        if buf.device != d_p.device:
                            buf = buf.to(d_p.device)
                            param_state['momentum_buffer'] = buf
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    # Keep momentum zeroed wherever the mask is zero (this
                    # also clears stale momentum after a mask change).
                    if mask is not None:
                        buf.mul_(mask)

                    d_p = d_p.add(buf, alpha=momentum) if nesterov else buf

                p.add_(d_p, alpha=-lr)

        return loss