In [1]:
import torch
import torch.nn.functional as F
from torch.nn.functional import (
    _mha_shape_check, _in_projection_packed, _in_projection, pad, softmax, dropout, linear
    )
from torch.nn.modules.activation import MultiheadAttention
import uuid
from typing import Dict, Optional

from torch import Tensor
from torch.overrides import (
    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
    handle_torch_function)

In [13]:
class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            print(full_key)
            incremental_state[full_key] = value
        return incremental_state

# @with_incremental_state
# as decorator, so the decorated class would also obtain those functions.
def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls

In [27]:
incremental_decoder = FairseqIncrementalState()
incremental_decoder._incremental_state_id

'ae925387-78b6-4fd6-9526-eefad65d17bb'

In [28]:
inremental_state = {}
k_v = {
    "prev_key": 101,
    "prev_value": 100
}

In [31]:
incremental_decoder.set_incremental_state(inremental_state, k_v.keys(), k_v.values())

ae925387-78b6-4fd6-9526-eefad65d17bb.dict_keys(['prev_key', 'prev_value'])


{'ae925387-78b6-4fd6-9526-eefad65d17bb.ae925387-78b6-4fd6-9526-eefad65d17bb': {'prev_key': 101,
  'prev_value': 100},
 "ae925387-78b6-4fd6-9526-eefad65d17bb.dict_keys(['prev_key', 'prev_value'])": dict_values([101, 100])}

In [32]:
inremental_state

{'ae925387-78b6-4fd6-9526-eefad65d17bb.ae925387-78b6-4fd6-9526-eefad65d17bb': {'prev_key': 101,
  'prev_value': 100},
 "ae925387-78b6-4fd6-9526-eefad65d17bb.dict_keys(['prev_key', 'prev_value'])": dict_values([101, 100])}

In [37]:
incremental_decoder.get_incremental_state(inremental_state, "ae925387-78b6-4fd6-9526-eefad65d17bb")

{'prev_key': 101, 'prev_value': 100}

In [36]:
inremental_state

{'ae925387-78b6-4fd6-9526-eefad65d17bb.ae925387-78b6-4fd6-9526-eefad65d17bb': {'prev_key': 101,
  'prev_value': 100},
 "ae925387-78b6-4fd6-9526-eefad65d17bb.dict_keys(['prev_key', 'prev_value'])": dict_values([101, 100])}