Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixed precision compound #159

Merged
merged 2 commits into from Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pylintrc
Expand Up @@ -285,7 +285,8 @@ good-names=i,
z,
d,
t,
fn
fn,
nu,

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -45,6 +45,12 @@ The format is based on [Keep a Changelog], and this project adheres to
particular device and optimizer choice. (\#144)
* Utility for visualizing the pulse response properties of a given
device configuration. (\#146)
* Optional power-law drift during analog training (\#158)
* A new abstract device (`MixedPrecisionCompound`) implementing an SGD
optimizer that computes the rank update in digital (assuming digital
high precision storage) and then transfers the matrix sequentially to
the analog device, instead of using the default fully parallel pulsed
update. (\#159)

#### Fixed

Expand Down
2 changes: 1 addition & 1 deletion src/aihwkit/simulator/configs/__init__.py
Expand Up @@ -14,5 +14,5 @@

from .configs import (
FloatingPointRPUConfig, InferenceRPUConfig, SingleRPUConfig,
UnitCellRPUConfig
UnitCellRPUConfig, DigitalRankUpdateRPUConfig
)
33 changes: 32 additions & 1 deletion src/aihwkit/simulator/configs/configs.py
Expand Up @@ -17,7 +17,7 @@

from aihwkit.simulator.configs.devices import (
ConstantStepDevice, FloatingPointDevice, IdealDevice, PulsedDevice,
UnitCell
UnitCell, DigitalRankUpdateCell
)
from aihwkit.simulator.configs.helpers import (
_PrintableMixin, tile_parameters_to_bindings
Expand Down Expand Up @@ -139,3 +139,34 @@ class InferenceRPUConfig(_PrintableMixin):
def as_bindings(self) -> devices.AnalogTileParameter:
    """Build the simulator bindings object that represents this instance."""
    bindings = tile_parameters_to_bindings(self)
    return bindings


@dataclass
class DigitalRankUpdateRPUConfig(_PrintableMixin):
    """Resistive processing unit configuration with a digital rank update.

    The forward and backward passes still run on an analog crossbar.
    Only the rank update is computed in digital; during update the
    digitally computed result is transferred to the analog crossbar
    using pulses.
    """

    bindings_class: ClassVar[Type] = devices.AnalogTileParameter

    device: DigitalRankUpdateCell = field(default_factory=DigitalRankUpdateCell)
    """Parameters of the digital rank update cell, including its analog device."""

    forward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters applied in the forward direction."""

    backward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters applied in the backward direction."""

    update: UpdateParameters = field(default_factory=UpdateParameters)
    """Parameters for the analog part of the update, i.e. the transfer
    from the digital buffer to the devices.
    """

    def as_bindings(self) -> devices.AnalogTileParameter:
        """Build the simulator bindings object that represents this instance."""
        bindings = tile_parameters_to_bindings(self)
        return bindings
117 changes: 116 additions & 1 deletion src/aihwkit/simulator/configs/devices.py
Expand Up @@ -12,7 +12,7 @@

"""Configuration for Analog (Resistive Device) tiles."""

# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-instance-attributes,too-many-lines

from copy import deepcopy
from dataclasses import dataclass, field
Expand Down Expand Up @@ -905,3 +905,118 @@ def as_bindings(self) -> devices.TransferResistiveDeviceParameter:
raise ConfigError("Could not add unit cell device parameter")

return transfer_parameters


###############################################################################
# Specific compound-devices with digital rank update
###############################################################################

@dataclass
class DigitalRankUpdateCell(_PrintableMixin):
    """Base parameters for cells that perform the rank update in digital.

    Devices derived from this class compute the rank update in digital
    and then (occasionally) transfer the information to the (analog)
    crossbar array that is used during forward and backward.
    """

    bindings_class: ClassVar[Type] = devices.AbstractResistiveDeviceParameter

    device: PulsedDevice = field(default_factory=ConstantStepDevice)
    """(Analog) device that is used for forward and backward."""

    def as_bindings(self) -> devices.AbstractResistiveDeviceParameter:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            NotImplementedError: always; this base class provides no conversion.
        """
        raise NotImplementedError

    def requires_diffusion(self) -> bool:
        """Return whether the wrapped analog device has diffusion enabled."""
        analog_device = self.device
        return analog_device.requires_diffusion()

    def requires_decay(self) -> bool:
        """Return whether the wrapped analog device has decay enabled."""
        analog_device = self.device
        return analog_device.requires_decay()


@dataclass
class MixedPrecisionCompound(DigitalRankUpdateCell):
    r"""Abstract device model that takes 1 (analog) device and
    implements a transfer-based learning rule, where the outer product
    is computed in digital.

    Here, the outer product of the activations and error is done on a
    full-precision floating-point :math:`\chi` matrix. Then, with a
    threshold given by the ``granularity``, pulses will be applied to
    transfer the information row-by-row to the analog matrix.

    For details, see `Nandakumar et al. Front. in Neurosci. (2020)`_.

    Note:
        This version of the update is different from the parallel
        update in analog that other devices implement with stochastic
        pulsing, as here :math:`{\cal O}(n^2)` digital computations are
        needed to compute the outer product (rank update). This need
        for digital compute in potentially high precision might result
        in inferior run time and power estimates in real-world
        applications, although sparse integer products can potentially
        be employed to improve run time estimates. For details, see
        the discussion in `Nandakumar et al. Front. in Neurosci. (2020)`_.

    .. _`Nandakumar et al. Front. in Neurosci. (2020)`: https://doi.org/10.3389/fnins.2020.00406
    """

    bindings_class: ClassVar[Type] = devices.MixedPrecResistiveDeviceParameter

    transfer_every: int = 1
    """Transfers every :math:`n` mat-vec operations (rounded to
    multiples/ratios of ``m_batch``).

    Standard setting is 1 for mixed precision, but it could potentially be
    reduced to get better run time estimates.
    """

    n_rows_per_transfer: int = -1
    r"""How many consecutive rows to write to the tile from the
    :math:`\chi` matrix. -1 means full matrix read each transfer event.
    """

    random_row: bool = False
    """Whether to select a random starting row for each transfer
    event and not take the next row that was previously not
    transferred as a starting row (the default).
    """

    granularity: float = 0.0
    r"""Granularity of the device that is used to calculate the number of
    pulses transferred from :math:`\chi` to analog.

    If 0, it will take ``dw_min`` from the analog device used.
    """

    n_x_bins: int = 0
    """The number of bins to discretize (symmetrically around zero) the
    activation before computing the outer product.

    Dynamic quantization is used by computing the absolute max value of each
    input. Quantization can be turned off by setting this to 0.
    """

    n_d_bins: int = 0
    """The number of bins to discretize (symmetrically around zero) the
    error before computing the outer product.

    Dynamic quantization is used by computing the absolute max value of each
    error vector. Quantization can be turned off by setting this to 0.
    """

    def as_bindings(self) -> devices.MixedPrecResistiveDeviceParameter:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if the bindings object rejects the analog device parameter.
        """
        # ``device`` is skipped by ``parameters_to_bindings`` and is attached
        # explicitly below.
        mixed_prec_parameter = parameters_to_bindings(self)
        param_device = self.device.as_bindings()

        if not mixed_prec_parameter.set_device_parameter(param_device):
            raise ConfigError("Could not add device parameter")

        return mixed_prec_parameter
5 changes: 2 additions & 3 deletions src/aihwkit/simulator/configs/helpers.py
Expand Up @@ -25,9 +25,8 @@ def parameters_to_bindings(params: Any) -> Any:
result = params.bindings_class()
for field, value in params.__dict__.items():
# Convert enums to the bindings enums.
if field == 'unit_cell_devices':
# Exclude `unit_cell_devices`, as it is a special field that is not
# present in the bindings.
if field in ('unit_cell_devices', 'device'):
# Exclude special fields that are not present in the bindings.
continue

if isinstance(value, Enum):
Expand Down
97 changes: 95 additions & 2 deletions src/aihwkit/simulator/presets/configs.py
Expand Up @@ -15,10 +15,11 @@
from dataclasses import dataclass, field

from aihwkit.simulator.configs.configs import (
SingleRPUConfig, UnitCellRPUConfig
SingleRPUConfig, UnitCellRPUConfig, DigitalRankUpdateRPUConfig
)
from aihwkit.simulator.configs.devices import (
PulsedDevice, TransferCompound, UnitCell, VectorUnitCell
PulsedDevice, TransferCompound, UnitCell, VectorUnitCell,
DigitalRankUpdateCell, MixedPrecisionCompound
)
from aihwkit.simulator.configs.utils import (
IOParameters, UpdateParameters, VectorUnitCellUpdatePolicy
Expand Down Expand Up @@ -375,3 +376,95 @@ class TikiTakaIdealizedPreset(UnitCellRPUConfig):
forward: IOParameters = field(default_factory=PresetIOParameters)
backward: IOParameters = field(default_factory=PresetIOParameters)
update: UpdateParameters = field(default_factory=PresetUpdateParameters)


# Mixed precision presets

@dataclass
class MixedPrecisionReRamESPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`ReRamESPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=ReRamESPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)


@dataclass
class MixedPrecisionReRamSBPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`ReRamSBPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=ReRamSBPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)


@dataclass
class MixedPrecisionCapacitorPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`CapacitorPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=CapacitorPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)


@dataclass
class MixedPrecisionEcRamPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`EcRamPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=EcRamPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)


@dataclass
class MixedPrecisionIdealizedPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`IdealizedPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=IdealizedPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)


@dataclass
class MixedPrecisionGokmenVlasovPreset(DigitalRankUpdateRPUConfig):
    """Configuration using mixed-precision with
    :class:`GokmenVlasovPresetDevice` and standard ADC/DAC hardware
    etc. configuration.
    """

    device: DigitalRankUpdateCell = field(
        default_factory=lambda: MixedPrecisionCompound(
            device=GokmenVlasovPresetDevice(),
        ))
    forward: IOParameters = field(default_factory=PresetIOParameters)
    backward: IOParameters = field(default_factory=PresetIOParameters)
    update: UpdateParameters = field(default_factory=PresetUpdateParameters)
2 changes: 2 additions & 0 deletions src/aihwkit/simulator/rpu_base_src/rpu_base.h
Expand Up @@ -15,6 +15,8 @@
#include "rpu_difference_device.h"
#include "rpu_expstep_device.h"
#include "rpu_linearstep_device.h"
#include "rpu_mixedprec_device.h"
#include "rpu_mixedprec_device_base.h"
#include "rpu_pulsed.h"
#include "rpu_simple_device.h"
#include "rpu_transfer_device.h"
Expand Down
43 changes: 43 additions & 0 deletions src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp
Expand Up @@ -25,6 +25,7 @@ void declare_rpu_devices(py::module &m) {
using VectorParam = RPU::VectorRPUDeviceMetaParameter<T>;
using DifferenceParam = RPU::DifferenceRPUDeviceMetaParameter<T>;
using TransferParam = RPU::TransferRPUDeviceMetaParameter<T>;
using MixedPrecParam = RPU::MixedPrecRPUDeviceMetaParameter<T>;

/*
* Trampoline classes for allowing inheritance.
Expand Down Expand Up @@ -211,6 +212,24 @@ void declare_rpu_devices(py::module &m) {
}
};

// Trampoline class for MixedPrecParam so that its virtual methods can be
// overridden from Python. Each method body is a single PYBIND11_OVERLOAD
// invocation naming the return type, the parent class and the method; per the
// pybind11 documentation the macro dispatches to a Python-side override when
// one exists and otherwise falls back to the MixedPrecParam implementation.
class PyMixedPrecParam : public MixedPrecParam {
public:
// No arguments are forwarded, hence the trailing comma inside the macro.
std::string getName() const override {
PYBIND11_OVERLOAD(std::string, MixedPrecParam, getName, );
}
// NOTE(review): returns a raw pointer; ownership of the clone is not visible
// from this block -- confirm against the base class contract.
MixedPrecParam *clone() const override {
PYBIND11_OVERLOAD(MixedPrecParam *, MixedPrecParam, clone, );
}
RPU::DeviceUpdateType implements() const override {
PYBIND11_OVERLOAD(RPU::DeviceUpdateType, MixedPrecParam, implements, );
}
// Forwards x_size, d_size and rng unchanged to the selected implementation.
RPU::MixedPrecRPUDevice<T> *
createDevice(int x_size, int d_size, RPU::RealWorldRNG<T> *rng) override {
PYBIND11_OVERLOAD(
RPU::MixedPrecRPUDevice<T> *, MixedPrecParam, createDevice, x_size, d_size, rng);
}
};

/*
* Python class definitions.
*/
Expand Down Expand Up @@ -445,6 +464,30 @@ void declare_rpu_devices(py::module &m) {
return ss.str();
});

// Python class "MixedPrecResistiveDeviceParameter" wrapping MixedPrecParam,
// registered with PyMixedPrecParam as its trampoline.
// NOTE(review): SimpleParam is presumably the registered parent class; its
// alias is declared outside this excerpt -- confirm.
py::class_<MixedPrecParam, PyMixedPrecParam, SimpleParam>(m, "MixedPrecResistiveDeviceParameter")
.def(py::init<>())
// Read/write attributes bound directly to the C++ members of the same name.
.def_readwrite("transfer_every", &MixedPrecParam::transfer_every)
.def_readwrite("n_rows_per_transfer", &MixedPrecParam::n_rows_per_transfer)
.def_readwrite("random_row", &MixedPrecParam::random_row)
.def_readwrite("granularity", &MixedPrecParam::granularity)
.def_readwrite("compute_sparsity", &MixedPrecParam::compute_sparsity)
.def_readwrite("n_x_bins", &MixedPrecParam::n_x_bins)
.def_readwrite("n_d_bins", &MixedPrecParam::n_d_bins)
// Forwards the given device parameter to setDevicePar() and hands its return
// value back to Python.
.def(
"set_device_parameter",
[](MixedPrecParam &self, const RPU::AbstractRPUDeviceMetaParameter<T> &dp) {
return self.setDevicePar(dp);
},
py::arg("parameter"),
R"pbdoc(
Set a pulsed base device parameter of a mixed precision device.
)pbdoc")
// str(obj) returns whatever printToStream() writes into a string stream.
.def("__str__", [](MixedPrecParam &self) {
std::stringstream ss;
self.printToStream(ss);
return ss.str();
});

/**
* Helper enums.
**/
Expand Down