Skip to content

Commit

Permalink
Drift python addition (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
maljoras committed Mar 10, 2021
1 parent 660f013 commit 8e1db65
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/aihwkit/simulator/configs/configs.py
Expand Up @@ -80,7 +80,7 @@ class UnitCellRPUConfig(_PrintableMixin):
"""Input-output parameter setting for the backward direction."""

update: UpdateParameters = field(default_factory=UpdateParameters)
"""Parameter for the update behavior."""
"""Parameter for the parallel analog update behavior."""

def as_bindings(self) -> devices.AnalogTileParameter:
"""Return a representation of this instance as a simulator bindings object."""
Expand Down
69 changes: 44 additions & 25 deletions src/aihwkit/simulator/configs/devices.py
Expand Up @@ -23,7 +23,8 @@
_PrintableMixin, parameters_to_bindings
)
from aihwkit.simulator.configs.utils import (
IOParameters, UpdateParameters, VectorUnitCellUpdatePolicy
IOParameters, UpdateParameters, VectorUnitCellUpdatePolicy,
DriftParameter, SimpleDriftParameter
)
from aihwkit.simulator.rpu_base import devices

Expand All @@ -43,6 +44,9 @@ class FloatingPointDevice(_PrintableMixin):
lifetime: float = 0.0
r"""One over `decay_rate`, ie :math:`1/r_\text{decay}`."""

drift: SimpleDriftParameter = field(default_factory=SimpleDriftParameter)
"""Parameter governing a power-law drift."""

def as_bindings(self) -> devices.FloatingPointTileParameter:
"""Return a representation of this instance as a simulator bindings object."""
return parameters_to_bindings(self)
Expand Down Expand Up @@ -73,25 +77,32 @@ class PulsedDevice(_PrintableMixin):
Resets the weight in cross points to (around) zero with
cycle-to-cycle and systematic spread around a mean.
Important:
Reset with given parameters is only activated when
:meth:`~aihwkit.simulator.tiles.base.Base.reset_weights` is
called explicitly by the user.
**Decay**:
.. math:: w_{ij} \leftarrow w_{ij}\,(1-\alpha_\text{decay}\delta_{ij})
Weight decay is generally off and has to be activated explicitly
by using :meth:`decay` on an analog tile. Note that the device
``decay_lifetime`` parameters (1 over decay rates
:math:`\delta_{ij}`) are analog tile specific and are thus set and
fixed during RPU initialization. :math:`\alpha_\text{decay}` is a
scaling factor that can be given during run-time.
Weight decay is only activated by inserting a specific call to
:meth:`~aihwkit.simulator.tiles.base.Base.decay_weights`, which is
done automatically for a tile each mini-batch is decay is
present. Note that the device ``decay_lifetime`` parameters (1
over decay rates :math:`\delta_{ij}`) are analog tile specific and
are thus set and fixed during RPU
initialization. :math:`\alpha_\text{decay}` is a scaling factor
that can be given during run-time.
**Diffusion**:
Similar to the decay, diffusion is only activated by inserting a specific
operator. However, the parameters of the diffusion
process are set during RPU initialization and are fixed for the
remainder.
Similar to the decay, diffusion is only activated by inserting a
specific call to
:meth:`~aihwkit.simulator.tiles.base.Base.diffuse_weights`, which is
done automatically for a tile each mini-batch is diffusion is
present. The parameters of the diffusion process are set during
RPU initialization and are fixed for the remainder.
.. math:: w_{ij} \leftarrow w_{ij} + \rho_{ij} \, \xi;
Expand All @@ -101,13 +112,24 @@ class PulsedDevice(_PrintableMixin):
Note:
If diffusion happens to move the weight beyond the hard bounds of the
weight it is ensured to be clipped appropriately.
** Drift **:
Optional power-law drift setting, as described in
:class:`~aihwkit.similar.configs.utils.DriftParameter`.
Important:
Similar to reset, drift is *not* applied automatically each
mini-batch but requires an explicit call to
:meth:`~aihwkit.simulator.tiles.base.Base.drift_weights` each
time the drift should be applied.
"""

bindings_class: ClassVar[Type] = devices.PulsedResistiveDeviceParameter

construction_seed: int = 0
"""If not equal 0, will set a unique seed for hidden parameters
during construction"""
during construction."""

corrupt_devices_prob: float = 0.0
"""Probability for devices to be corrupt (weights fixed to random value
Expand All @@ -122,6 +144,9 @@ class PulsedDevice(_PrintableMixin):
diffusion_dtod: float = 0.0
"""Device-to device variation of diffusion rate in relative units."""

drift: DriftParameter = field(default_factory=DriftParameter)
"""Parameter governing a power-law drift."""

dw_min: float = 0.001
"""Mean of the minimal update step sizes across devices and directions."""

Expand Down Expand Up @@ -269,7 +294,7 @@ class IdealDevice(_PrintableMixin):

construction_seed: int = 0
"""If not equal 0, will set a unique seed for hidden parameters
during construction"""
during construction."""

diffusion: float = 0.0
"""Standard deviation of diffusion process."""
Expand Down Expand Up @@ -353,7 +378,6 @@ class LinearStepDevice(PulsedDevice):
w_{ij} &\leftarrow& \text{clip}(w_{ij},b^\text{min}_{ij},b^\text{max}_{ij})
\end{eqnarray*}
in case of additive noise. Optionally, multiplicative noise can
be chosen in which case the first equation becomes:
Expand Down Expand Up @@ -424,7 +448,7 @@ class LinearStepDevice(PulsedDevice):

allow_increasing: bool = False
"""Whether to allow the situation where update sizes increase
towards the bound instead of saturating (and thus becoming smaller)
towards the bound instead of saturating (and thus becoming smaller).
"""

mean_bound_reference: bool = True
Expand All @@ -443,7 +467,7 @@ class LinearStepDevice(PulsedDevice):

mult_noise: bool = True
"""Whether to use multiplicative noise instead of additive
cycle-to-cycle noise"""
cycle-to-cycle noise."""

write_noise_std: float = 0.0
r"""Whether to use update write noise that is added to the updated
Expand Down Expand Up @@ -477,7 +501,7 @@ class SoftBoundsDevice(PulsedDevice):

mult_noise: bool = True
"""Whether to use multiplicative noise instead of additive
cycle-to-cycle noise"""
cycle-to-cycle noise."""


@dataclass
Expand Down Expand Up @@ -752,32 +776,27 @@ class TransferCompound(UnitCell):
The weight that is seen in the forward and backward pass is
governed by the :math:`\gamma` weightening setting.
Note:
Here the devices could be either transferred in analog
(essentially within the unit cell) or on separate arrays (using
the usual (non-ideal) forward pass and update steps. This can be
set with ``transfer_forward`` and ``transfer_update``.
.. _Gokmen & Haensch (2020): https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
"""

bindings_class: ClassVar[Type] = devices.TransferResistiveDeviceParameter

gamma: float = 0.0
r"""
Weightening factor to compute the effective SGD weight from the
r"""Weighting factor to compute the effective SGD weight from the
hidden matrices. The default scheme is:
.. math:: g^{n-1} W_0 + g^{n-2} W_1 + \ldots + g^0 W_{n-1}
"""

gamma_vec: List[float] = field(default_factory=list,
metadata={'hide_if': []})
"""
User-defined weightening can be given as a list if weights in
"""User-defined weightening can be given as a list if weights in
which case the default weightening scheme with ``gamma`` is not
used.
"""
Expand Down
118 changes: 118 additions & 0 deletions src/aihwkit/simulator/configs/utils.py
Expand Up @@ -510,3 +510,121 @@ class WeightClipParameter(_PrintableMixin):

type: WeightClipType = WeightClipType.NONE
"""Type of clipping."""


@dataclass
class SimpleDriftParameter(_PrintableMixin):
r"""Parameter for a simple power law drift.
The drift as a simple power law drift without device-to-device
variation or conductance dependence.
It computes:
.. math::
w_{ij}*\left(\frac{t + \Delta t}{t_0}\right)^(-\nu)
"""

bindings_class: ClassVar[Type] = devices.DriftParameter

nu: float = 0.0
r"""Average drift :math:`\nu` value. Need to non-zero to actually use the drift."""

t_0: float = 1.0
"""Time between write and first read.
Usually assumed in milliseconds, however, it really determines the time
units of ``time_since_last_call`` when calling the drift.
"""

reset_tol: float = 1e-7
"""Reset tolerance.
This should a number smaller than the expected weight change as it
is used to detect any changes in the weight from the last drift
call. Every change to the weight above this tolerance will reset
the drift time.
Caution:
Any write noise or diffusion on the weight might thus
interfere with the drift.
"""


@dataclass
class DriftParameter(SimpleDriftParameter):
r"""Parameter for a power law drift.
The drift is based on the model described by `Oh et al (2019)`_
It computes:
.. math::
w_{ij}*\left(\frac{t + \Delta t}{t_0}\right)^(-\nu^\text{actual}_{ij})
where the drift coefficient is drawn once at the beginning and
might depend on device. It also can depend on the actual weight
value.
The actual drift coefficient is computed as:
.. math::
\nu_{ij}^\text{actual} = \nu_{ij} - \nu_k \log \frac{(w_{ij} - w_\text{off}) / r_\text{wg}
+ g_\text{off}}{G_0} + \nu\sigma_\nu\xi
here :math:`w_{ij}` is the actual weight and `\nu_{ij}` fixed for
each device given by the mean :math:`\nu` and the device-to-device
variation: :math:`\nu_{ij} = \nu + \nu_dtod\nu\xi` and are only
drawn once at the beginning (tile instantiation). `\xi` is
Gaussian noise.
Note:
If the weight has changed from the last drift call (determined
by the ``reset_tol`` parameter), for instance due to update,
decay or noise, then the drift time :math:`t` will be reset and start
from new, however, the drift coefficients :math:`\nu_{ij}` are
*not* changed. On the other hand, if the weights has not
changed since last call, :math:`t` will accumulate the time.
Caution:
Note that the drift coefficient does *not* depend on the initially
programmed weight value at :math:`t=0` in the current
implementation (ie G0 is a constant for all devices), but
instead on the actual weight. In some materials (e.g. phase
changed materials), that might be not accurate.
.. _`Oh et al (2019)`: https://ieeexplore.ieee.org/document/8753712
"""

bindings_class: ClassVar[Type] = devices.DriftParameter

nu_dtod: float = 0.0
r"""Device-to-device variation of the :math:`\nu` values."""

nu_std: float = 0.0
r"""Cycle-to-cycle variation of :math:`\nu`.
A more realistic way to add noise of the drift might be using
``w_noise_std``.
"""

wg_ratio: float = 1.0
"""``(w_max-w_min)/(g_max-g_min)`` to convert to physical units."""

g_offset: float = 0.0
"""``g_min`` to convert to physical units."""

w_offset: float = 0.0
"""``w(g_min)``, i.e. to what value ``g_min`` is mapped to in w-space."""

nu_k: float = 0.0
r"""Variation of "math:`nu` with :math:`W`.
ie. :math:`\nu(R) = nu_0 - k \log(G/G_0)`.
See Oh et al.
"""

log_g0: float = 0.0
"""Log g0."""

w_noise_std: float = 0.0
"""Additional weight noise (Gaussian diffusion) added to the weights
after the drift is applied."""
17 changes: 16 additions & 1 deletion src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp
Expand Up @@ -228,7 +228,8 @@ void declare_rpu_devices(py::module &m) {
})
// Properties from this class.
.def_readwrite("diffusion", &RPU::SimpleMetaParameter<T>::diffusion)
.def_readwrite("lifetime", &RPU::SimpleMetaParameter<T>::lifetime);
.def_readwrite("lifetime", &RPU::SimpleMetaParameter<T>::lifetime)
.def_readwrite("drift", &RPU::SimpleMetaParameter<T>::drift);

py::class_<RPU::PulsedMetaParameter<T>>(m, "AnalogTileParameter")
.def(py::init<>())
Expand Down Expand Up @@ -282,6 +283,20 @@ void declare_rpu_devices(py::module &m) {
.def_readwrite("w_noise", &RPU::IOMetaParameter<T>::w_noise)
.def_readwrite("w_noise_type", &RPU::IOMetaParameter<T>::w_noise_type);

py::class_<RPU::DriftParameter<T>>(m, "DriftParameter")
.def(py::init<>())
.def_readwrite("nu", &RPU::DriftParameter<T>::nu)
.def_readwrite("nu_dtod", &RPU::DriftParameter<T>::nu_dtod)
.def_readwrite("nu_std", &RPU::DriftParameter<T>::nu_std)
.def_readwrite("wg_ratio", &RPU::DriftParameter<T>::wg_ratio)
.def_readwrite("g_offset", &RPU::DriftParameter<T>::g_offset)
.def_readwrite("w_offset", &RPU::DriftParameter<T>::w_offset)
.def_readwrite("nu_k", &RPU::DriftParameter<T>::nu_k)
.def_readwrite("log_g0", &RPU::DriftParameter<T>::logG0)
.def_readwrite("t_0", &RPU::DriftParameter<T>::t0)
.def_readwrite("reset_tol", &RPU::DriftParameter<T>::reset_tol)
.def_readwrite("w_noise_std", &RPU::DriftParameter<T>::w_read_std);

// device params
py::class_<AbstractParam, PyAbstractParam, RPU::SimpleMetaParameter<T>>(
m, "AbstractResistiveDeviceParameter")
Expand Down
18 changes: 18 additions & 0 deletions src/aihwkit/simulator/rpu_base_src/rpu_base_tiles.cpp
Expand Up @@ -247,6 +247,24 @@ void declare_rpu_tiles(py::module &m) {
Args:
alpha: decay scale
)pbdoc")
.def(
"drift_weights", [](Class &self, float time_since_last_call) { self.driftWeights(time_since_last_call); },
py::arg("time_since_last_call"),
R"pbdoc(
Drift weights according to a power law::
W = W0*(delta_t/t0)^(-nu_actual)
Applies the weight drift to all unchanged weight elements
(judged by ``reset_tol``) and resets the drift for those
that have changed (nu is not re-drawn, however). Each
device might have a different version of this drift.
Args:
time_since_last_call: This is the time between the calls (``delta_t``),
typically the time to process a mini-batch for the
network.
)pbdoc")
.def(
"clip_weights",
[](Class &self, ::RPU::WeightClipParameter &wclip_par) { self.clipWeights(wclip_par); },
Expand Down

0 comments on commit 8e1db65

Please sign in to comment.