## Primitive update method

In [None]:
import torch
from aihwkit.exceptions import TorchTileConfigError
from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.simulator.parameters.enums import WeightClipType
from aihwkit.simulator.parameters.inference import WeightClipParameter, WeightModifierParameter
from aihwkit.simulator.tiles.custom import CustomSimulatorTile
from torch import Tensor

# TODO: implement the methods below.
# Implement them in conjunction with aihwkit's PulseType values
# (aihwkit.simulator.parameters.enums.PulseType).
# Also try adding a PCM inference noise model, etc.

# Note: this can also be implemented in AnalogMVM.matmul() instead of SimulatorTile.forward().
# AnalogMVM.matmul()
# -> SimulatorTile.forward()
# -> PeripheryTile.joint_forward()
# -> AnalogFunction.forward()
# -> CustomTile.forward()
# -> PeripheryTile.read_weights()/program_weights()
# -> used by PeripheryTile.get_weights()/set_weights(realistic=True)


class _HalfSelectMixin:
    """Implements the half-selected update method."""

    def half_selection(self):
        pass


class HalfSelectedSimulatorTile(_HalfSelectMixin, CustomSimulatorTile):
    """Custom simulator tile that runs a half-selection hook after each update.

    Forward passes may use perturbed weights supplied by :meth:`modify_weight`
    (currently a placeholder that supplies none). Weights can be clipped with
    :meth:`clip_weights`.
    """

    def set_weights(self, weight: Tensor, **kwargs) -> None:
        """Set the tile weights.

        Args:
            weight: ``[out_size, in_size]`` weight matrix.
            kwargs: forwarded unchanged to the parent implementation.
        """
        super().set_weights(weight, **kwargs)

    def get_weights(self) -> Tensor:
        """Get the tile weights.

        Returns:
            The value returned by the parent ``get_weights`` (the
            ``[out_size, in_size]`` weight matrix).
        """
        # Must propagate the parent's result; otherwise callers receive None.
        return super().get_weights()

    def update(
        self,
        x_input: Tensor,
        d_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Perform the parent's rank-1 tile update, then apply half-selection.

        Note:
            Arguments are forwarded unchanged to the parent ``update``.

        Returns:
            Whatever the parent ``update`` returns.

        Raises:
            TileError: raised by the parent in case transposed input / output
                or bias is requested.
        """
        result = super().update(x_input, d_input, bias, in_trans, out_trans, non_blocking)
        # Model the disturbance of half-selected cells after the regular update.
        self.half_selection()
        return result

    def forward(
        self,
        x_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        is_test: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Analog forward pass, optionally using perturbed weights.

        When ``is_test`` is ``False``, :meth:`modify_weight` is asked for a
        perturbed copy of the weights. If it returns ``None`` (no perturbation),
        the parent forward is used unchanged. Otherwise the perturbed weights
        are swapped in only for the duration of the parent forward call and
        the original weights are restored afterwards.

        Returns:
            The output of the parent ``forward``.
        """
        noisy_weights = None
        if not is_test:
            # `_modifier` only exists after `set_config` ran; treat absence as "no modifier".
            noisy_weights = self.modify_weight(
                self.weight, getattr(self, "_modifier", None), x_input.shape[0]
            )

        if noisy_weights is None:
            return super().forward(x_input, bias, in_trans, out_trans, is_test, non_blocking)

        # NOTE(review): assumes the parent forward reads `self.weight`; the swap
        # reuses the parent's forward computation with the perturbed weights.
        clean_weights = self.weight.data
        self.weight.data = noisy_weights.to(device=clean_weights.device, dtype=clean_weights.dtype)
        try:
            return super().forward(x_input, bias, in_trans, out_trans, is_test, non_blocking)
        finally:
            self.weight.data = clean_weights

    @staticmethod
    def modify_weight(inp_weight: Tensor, modifier: WeightModifierParameter, batch_size: int):
        """Return a perturbed copy of ``inp_weight`` for one forward pass.

        Placeholder: no weight-modifier model is implemented yet, so this
        returns ``None``, which ``forward`` treats as "use the unmodified weights".

        Args:
            inp_weight: the current weight matrix.
            modifier: weight-modifier parameters (may be ``None``).
            batch_size: size of the input batch.
        """
        return None

    def set_config(self, rpu_config: "TorchInferenceRPUConfig") -> None:
        """Update the configuration to allow on-the-fly changes.

        The parent's ``set_config`` (if any) runs first so its settings are kept.
        Then the forward IO parameters and the weight modifier are cached.

        Args:
            rpu_config: configuration to use in the next forward passes.
        """
        parent_set_config = getattr(super(), "set_config", None)
        if parent_set_config is not None:
            parent_set_config(rpu_config)
        self._f_io = rpu_config.forward
        # Plain CustomRPUConfig instances carry no `modifier`; fall back to None.
        self._modifier = getattr(rpu_config, "modifier", None)

    @torch.no_grad()
    def clip_weights(self, clip: WeightClipParameter) -> None:
        """Clip the weights in place.

        Called by ``InferenceTileWithPeriphery.post_update_step()``.

        Args:
            clip: parameters specifying the clipping method and type.

        Raises:
            NotImplementedError: for ``WeightClipType.AVERAGE_CHANNEL_MAX``.
            TorchTileConfigError: if an unknown ``WeightClipType`` is used.
        """
        if clip.type == WeightClipType.FIXED_VALUE:
            self.weight.data = torch.clamp(self.weight, -clip.fixed_value, clip.fixed_value)
        elif clip.type == WeightClipType.LAYER_GAUSSIAN:
            # Clip at `sigma` standard deviations of the whole weight matrix,
            # optionally capped by a positive fixed value.
            alpha = self.weight.std() * clip.sigma
            if clip.fixed_value > 0:
                alpha = min(clip.fixed_value, alpha)
            self.weight.data = torch.clamp(self.weight, -alpha, alpha)
        elif clip.type == WeightClipType.AVERAGE_CHANNEL_MAX:
            raise NotImplementedError
        else:
            raise TorchTileConfigError(f"Unknown clip type {clip.type}")

## CustomTile

In [None]:
from typing import Any, Callable, Optional, Tuple, Type

from aihwkit.exceptions import AnalogBiasConfigError, TileError
from aihwkit.simulator.tiles.base import SimulatorTile, SimulatorTileWrapper
from aihwkit.simulator.tiles.custom import CustomRPUConfig
from aihwkit.simulator.tiles.functions import AnalogFunction
from aihwkit.simulator.tiles.inference import InferenceTileWithPeriphery
from aihwkit.simulator.tiles.module import TileModule


class RealisticTile(TileModule, InferenceTileWithPeriphery, SimulatorTileWrapper):
    """Periphery tile whose simulator tile and programming routine come from the config.

    The simulator tile is created from ``rpu_config.simulator_tile_class``.
    When the configuration provides a ``program_weights`` callable, it is bound
    to the instance as its ``program_weights`` method.

    Note:
        The methods below are adapted from aihwkit's ``CustomTile``.
    """

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: Optional["RPUConfigwithProgram"],
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
    ):
        """Create the tile.

        Args:
            out_size: output size (rows of the ``[out_size, in_size]`` weights).
            in_size: input size (columns of the weights).
            rpu_config: tile configuration; ``None`` falls back to a default
                ``CustomRPUConfig``.
            bias: whether the tile uses a bias.
            in_trans: not supported; must be ``False``.
            out_trans: not supported; must be ``False``.

        Raises:
            TileError: if ``in_trans`` or ``out_trans`` is requested.
            AnalogBiasConfigError: if the tile ends up with an analog bias.
        """
        if in_trans or out_trans:
            raise TileError("in/out trans is not supported.")

        if rpu_config is None:
            rpu_config = CustomRPUConfig()

        TileModule.__init__(self)
        SimulatorTileWrapper.__init__(
            self,
            out_size,
            in_size,
            rpu_config,  # type: ignore
            bias,
            in_trans,
            out_trans,
            torch_update=True,
        )
        InferenceTileWithPeriphery.__init__(self)

        if self.analog_bias:
            raise AnalogBiasConfigError("Analog bias is not supported for the torch tile")

        # Bind the configured programming routine so it receives this tile as `self`.
        # Some configs have no routine: the CustomRPUConfig fallback has no such field,
        # and the dataclass default is None. Those keep the class's own
        # `program_weights` (if any) instead of failing on `None.__get__`.
        program_weights = getattr(rpu_config, "program_weights", None)
        if program_weights is not None:
            self.program_weights = program_weights.__get__(self, type(self))

    def _create_simulator_tile(  # type: ignore
        self, x_size: int, d_size: int, rpu_config: "CustomRPUConfig"
    ) -> "SimulatorTile":
        """Create a simulator tile.

        Args:
            x_size: input size of the tile.
            d_size: output size of the tile.
            rpu_config: resistive processing unit configuration; its
                ``simulator_tile_class`` is instantiated.

        Returns:
            a simulator tile based on the specified configuration.
        """
        return rpu_config.simulator_tile_class(x_size=x_size, d_size=d_size, rpu_config=rpu_config)

    def forward(
        self, x_input: Tensor, tensor_view: Optional[Tuple] = None  # type: ignore
    ) -> Tensor:
        """Torch forward function that calls the analog context forward.

        Args:
            x_input: input tensor.
            tensor_view: view used to broadcast output scales and the digital
                bias; derived from the output rank when ``None``.

        Returns:
            The scaled analog output, plus the digital bias when enabled.
        """
        # pylint: disable=arguments-differ

        # Re-apply the config to enable on-the-fly changes. With caution: this
        # might change the rpu config for backward / update while another
        # forward is in flight.
        self.tile.set_config(self.rpu_config)

        out = AnalogFunction.apply(
            self.get_analog_ctx(), self, x_input, self.shared_weights, not self.training
        )

        if tensor_view is None:
            tensor_view = self.get_tensor_view(out.dim())
        out = self.apply_out_scaling(out, tensor_view)

        if self.digital_bias:
            return out + self.bias.view(*tensor_view)
        return out

## RPUConfig

In [None]:
# TODO: attach the program_weights method directly to the dataclass?
from dataclasses import dataclass, field


@dataclass
class RPUConfigwithProgram(CustomRPUConfig):
    """Custom RPU configuration with a pluggable weight-programming routine.

    ``RealisticTile`` binds ``program_weights`` to the tile instance, so the
    callable must accept the tile as its first (``self``) argument.
    """

    # Routine used to program the weights; receives the tile as `self`.
    # `None` means no custom routine is configured.
    program_weights: Optional[Callable[..., None]] = None

    # Tile class that corresponds to this RPUConfig.
    tile_class: Type = RealisticTile

    # Simulator tile class implementing the analog forward / backward / update.
    simulator_tile_class: Type = HalfSelectedSimulatorTile

    # Weight-clipping parameters; defaults to fixed-value clipping at +/-1.0.
    clip: WeightClipParameter = field(
        default_factory=lambda: WeightClipParameter(
            type=WeightClipType.FIXED_VALUE, fixed_value=1.0
        )
    )

    # Weight-modifier parameters, read by the simulator tile's `set_config`.
    modifier: WeightModifierParameter = field(default_factory=WeightModifierParameter)

## Program Methods

### Gradient descent-based programming

In [None]:
from typing import Optional, Union

from torch import Tensor

from src.utils.pylogger import RankedLogger

# Module-level logger used by the programming routines below; their
# "Error: ..." debug lines are later parsed from captured logs.
log = RankedLogger()

# TODO: is this realistic?


@torch.no_grad()
def program_weights_gdp(
    self,
    from_reference: bool = True,
    x_values: Optional[Tensor] = None,
    learning_rate: float = 1,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
) -> None:
    """Program the target weights into the conductances using the
    pulse update defined.

    Programming is done using the defined tile-update (e.g. SGD)
    and matching inputs (`x_values`, by default `eye`).

    Args:
        from_reference: Whether to use weights from reference
            (those that were initially set with `set_weights`) or
            the current weights.
        x_values: Values to use for the read-and-verify. If none
            are given, unit vectors are used.
        learning_rate: Learning rate of the optimization.
        max_iter: Max number of batches for the iterative programming.
        tolerance: Stop the iteration loop early if the mean
            output deviation is below this number. Given in
            relation to the max output. ``None`` disables early stopping.
        w_init: Initial weight matrix to start from. If given as
            float, weights are set uniform random in `[-w_init,
            w_init]`. This init weight is given directly in
            normalized conductance units and should include the
            bias row if existing.
    """
    if not from_reference or self.reference_combined_weights is None:
        self.reference_combined_weights = self.tile.get_weights()
    # The target is the reference, whether it was just read or stored earlier.
    target_weights = self.reference_combined_weights

    if x_values is None:
        x_values = torch.eye(self.tile.get_x_size())
    x_values = x_values.to(self.device)
    # Outputs an ideal tile holding the target weights would produce.
    target_values = x_values @ target_weights.to(self.device).T

    # Normalizer for the stopping criterion; an all-zero target falls back to
    # absolute error instead of raising ZeroDivisionError.
    target_max = target_values.abs().max().item() or 1.0

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(learning_rate)  # type: ignore
    try:
        for _ in range(max_iter):
            y = self.tile.forward(x_values, False)
            # TODO: analyze the relation between the output error and the
            # 2-norm of the weight error.
            error = y - target_values
            err_normalized = error.abs().mean().item() / target_max
            # Logged metric: spectral norm of the weight error, as a plain
            # float so "Error: <number>" can be parsed from the logs.
            weight_diff = self.tile.get_weights() - target_weights
            l2_norm = torch.linalg.matrix_norm(weight_diff, ord=2).item()
            log.debug(f"Error: {l2_norm}")
            if tolerance is not None and err_normalized < tolerance:
                break
            self.tile.update(x_values, error, False)  # type: ignore
    finally:
        # Restore the tile's learning rate even if programming fails midway.
        self.tile.set_learning_rate(lr_save)  # type: ignore

In [None]:
# Logger for the routines in this cell.
# NOTE(review): redundant with the logger created in the previous cell;
# presumably kept so this cell can run on its own.
log = RankedLogger()


@torch.no_grad()
def program_weights_gdp2(
    self,
    batch_size: int = 5,
    learning_rate: float = 1,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
) -> None:
    """Program the target weights into the conductances using the
    pulse update defined.

    Mini-batch variant of the `program_weights_gdp` method. The tile's
    current weights are the target and the unit vectors are the probe inputs.
    Iteration ``i`` uses probes ``i * batch_size`` up to
    ``(i + 1) * batch_size - 1``, with indices wrapping around cyclically, so
    every batch holds exactly ``batch_size`` probes.

    Args:
        batch_size: Number of probe inputs per iteration; must be positive.
        learning_rate: Learning rate used by the tile while programming.
        max_iter: Maximum number of mini-batch iterations.
        tolerance: Stop early once the mean absolute output error of the
            current batch, relative to the largest target output, is below
            this value. ``None`` disables early stopping.
        w_init: Initial weights. A tensor is set directly; a float sets
            uniform random weights in ``[-w_init, w_init]``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    target_weights = self.tile.get_weights()

    x_values = torch.eye(self.tile.get_x_size()).to(self.device)
    target_values = x_values @ target_weights.to(self.device).T
    num_probes = x_values.shape[0]

    # Normalizer for the stopping criterion; an all-zero target falls back to
    # absolute error instead of raising ZeroDivisionError.
    target_max = target_values.abs().max().item() or 1.0

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(learning_rate)  # type: ignore
    try:
        for i in range(max_iter):
            # Cyclic mini-batch: indices wrap modulo the number of probes, so
            # the batch size stays fixed no matter how many passes are made.
            idx = (
                torch.arange(i * batch_size, (i + 1) * batch_size, device=x_values.device)
                % num_probes
            )
            x = x_values[idx]
            target = target_values[idx]

            y = self.tile.forward(x, False)
            error = y - target
            err_normalized = error.abs().mean().item() / target_max
            # Logged metric: spectral norm of the weight error, as a plain
            # float so "Error: <number>" can be parsed from the logs.
            weight_diff = self.tile.get_weights() - target_weights
            l2_norm = torch.linalg.matrix_norm(weight_diff, ord=2).item()
            log.debug(f"Error: {l2_norm}")
            if tolerance is not None and err_normalized < tolerance:
                break
            self.tile.update(x, error, False)  # type: ignore
    finally:
        # Restore the tile's learning rate even if programming fails midway.
        self.tile.set_learning_rate(lr_save)  # type: ignore

### Proposed method (SVD)

In [None]:
def compensate_half_selection(v: Tensor) -> Tensor:
    """Pre-distort ``v`` to cancel the half-selection effect.

    Placeholder: no compensation model exists yet, so the very same tensor
    object is handed back untouched (no copy is made).

    Args:
        v: Vector to compensate.

    Returns:
        ``v`` itself.
    """
    compensated = v  # identity until a compensation model is implemented
    return compensated


@torch.no_grad()
def program_weights_svd(
    self,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.0,
    rank_atol: Optional[float] = 1e-2,
    svd_once: bool = False,
    **kwargs: Any,
) -> None:
    """Perform singular value decomposition (SVD) based weight programming.

    The tile's current weights are the target. After (re)initializing the
    tile, the weight error ``diff = W - target`` is removed one singular
    component at a time. Each step passes ``sqrt(S_i) * v_i`` as input and
    ``sqrt(S_i) * u_i`` as error to ``tile.update`` with learning rate 1.

    Args:
        max_iter: Maximum number of rank-1 updates. Defaults to 100.
        tolerance: Stop early once the spectral norm of the remaining weight
            error is below this value; ``None`` disables early stopping.
            Defaults to 0.01.
        w_init: Initial weights. A tensor is set directly; a float sets
            uniform random weights in ``[-w_init, w_init]``. Defaults to 0.0.
        rank_atol: Singular values of the initial error at or below this
            threshold are not counted towards its numerical rank, which caps
            the number of updates. ``None`` uses
            ``S.max() * max(diff.shape) * eps``. Defaults to 1e-2.
        svd_once: If ``True``, decompose the initial error once and apply its
            components in order. Otherwise recompute the SVD of the remaining
            error after every update and always apply its leading component.
            Defaults to False.
        **kwargs: Accepted and ignored.

    Returns:
        None
    """
    target_weights = self.tile.get_weights()

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    # tile.update() applies w -= lr * delta_w, so with lr = 1 an update built
    # from the error diff = w - target moves the weights toward the target.
    self.tile.set_learning_rate(1)  # type: ignore
    try:
        diff = self.tile.get_weights() - target_weights
        U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
        if rank_atol is None:
            rank_atol = S.max() * max(diff.shape) * torch.finfo(S.dtype).eps
        rank = int(torch.sum(S > rank_atol).item())

        i = 0
        for _ in range(min(max_iter, rank)):
            # Split sqrt(S_i) evenly between both factors (new tensors, so the
            # decomposition itself is left untouched).
            sqrt_s = torch.sqrt(S[i])
            u = U[:, i] * sqrt_s
            v = Vh[i, :] * sqrt_s
            u1, v1 = compensate_half_selection(u), compensate_half_selection(v)
            self.tile.update(v1, u1, False)

            diff = self.tile.get_weights() - target_weights
            if svd_once:
                # Keep the initial decomposition and move to its next component.
                l2_norm = torch.linalg.matrix_norm(diff, ord=2).item()
                i += 1
            else:
                # Re-decompose the remaining error; its leading singular value
                # is the current spectral norm and its component comes next.
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                l2_norm = S[0].item()
            log.debug(f"Error: {l2_norm}")
            if tolerance is not None and l2_norm < tolerance:
                break
    finally:
        # Restore the tile's learning rate even if programming fails midway.
        self.tile.set_learning_rate(lr_save)  # type: ignore

## Experimental Results

### Util & setup

In [None]:
# From a list of log strings, find the "Error: <value>" pattern and extract
# the values into a list.
import re

from src.utils.logging_utils import LogCapture


def extract_error(log_list):
    """Extract the numeric values of ``"Error: <value>"`` log lines.

    A value logged as a tensor repr (e.g. ``"Error: tensor(0.12)"``) is
    accepted as well. Lines without a parsable ``Error: <number>`` are skipped
    instead of raising. This includes lines that merely contain the word
    "Error", such as exception names.

    Args:
        log_list: iterable of captured log message strings.

    Returns:
        list of floats, in the order they appear in ``log_list``.
    """
    pattern = re.compile(
        r"Error: (?:tensor\()?([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    )
    err_list = []
    for line in log_list:
        match = pattern.search(line)
        if match is not None:
            err_list.append(float(match.group(1)))
    return err_list

In [None]:
import torch

input_size = 500
output_size = 300
batch_size = 5

# Build a random (input_size, output_size) matrix of exact rank `rank` by
# truncating the SVD of a Gaussian matrix (torch.svd is deprecated in favor
# of torch.linalg.svd, which returns Vh = V^T).
rank = 20
u, s, vh = torch.linalg.svd(torch.randn(input_size, output_size), full_matrices=False)
w = u[:, :rank] @ torch.diag(s[:rank]) @ vh[:rank, :]

### AnalogTile

In [None]:
from aihwkit.simulator.configs import SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice, DriftParameter, IdealDevice
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.tiles.analog import AnalogTile

# Single-device RPU config: ideal device, STOCHASTIC_COMPRESSED pulse updates,
# and pre/post-processing with the input-range option enabled.
pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=True))
# device_cfg = ConstantStepDevice(diffusion=0)
device_cfg = IdealDevice()
update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC_COMPRESSED)
rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)
# rpuconfig = SingleRPUConfig()
atile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)  # with periphery
# Copy atile's state into a second tile (presumably so both experiments start
# from the same state). `state_dict(atile_dic)` passes the dict positionally as
# the `destination` argument, which is filled in place.
atile_dic = {}
atile.state_dict(atile_dic)
atile2 = AnalogTile(output_size, input_size, rpu_config=rpuconfig)
atile2.load_state_dict(atile_dic, assign=True)
print(rpuconfig)

In [None]:
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

# Override `program_weights` per instance with bound methods:
# `atile` uses mini-batch gradient-descent programming, `atile2` the SVD method.
atile.program_weights = program_weights_gdp2.__get__(atile, TileWithPeriphery)
atile2.program_weights = program_weights_svd.__get__(atile2, TileWithPeriphery)

In [None]:
# Store the low-rank target in `atile` (set_weights takes [out_size, in_size],
# hence `w.T`). program_weights_gdp2 reads it back as its target, then
# programs it; the "Error: ..." log lines are captured for plotting.
with LogCapture() as logc:
    atile.tile.set_weights(w.T)
    atile.program_weights(batch_size=batch_size, tolerance=1e-5)
    log_list1 = logc.get_log_list()

In [None]:
# Same target for `atile2`, programmed with the SVD method (default arguments);
# the "Error: ..." log lines are captured for plotting.
with LogCapture() as logc:
    # atile2.set_weights(weight=w, realistic=True)
    atile2.tile.set_weights(w.T)
    atile2.program_weights()
    log_list2 = logc.get_log_list()

In [None]:
import matplotlib.pyplot as plt

# Compare convergence of both programming methods on a log scale.
err_list = extract_error(log_list1)
err_list2 = extract_error(log_list2)
plt.semilogy(err_list)
plt.semilogy(err_list2)
# set legend
plt.legend([f"gdp-seq(batchsize {batch_size})", "svd"])
plt.xlabel("Iteration")
plt.ylabel("L2 norm of weight error (largest singular value)")
plt.title("Error vs Iteration @ {}x{}, rank={}".format(input_size, output_size, rank))
plt.show()

#### GDP batch-size effect

In [None]:
# Sweep the GDP mini-batch size. A dedicated loop variable keeps the global
# `batch_size` (used in the comparison plot's legend) from being clobbered.
for sweep_batch_size in [1, 5, 10, 20, 50]:
    with LogCapture() as logc:
        atile.tile.set_weights(w.T)
        atile.program_weights(batch_size=sweep_batch_size)
        log_list = logc.get_log_list()
    err_list = extract_error(log_list)
    plt.semilogy(err_list, label=f"batch_size={sweep_batch_size}")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("L2 norm of weight error")
plt.title(
    "{}x{} rank={} matrix with {}".format(
        input_size, output_size, rank, atile.rpu_config.device.__class__.__name__
    )
)
plt.show()

### CustomTile

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# Baseline aihwkit CustomTile with its default configuration; read its weights
# through the `realistic=True` path.
ctile = CustomTile(output_size, input_size)
ctile.get_weights(realistic=True)

### RealisticTile (ours)

In [None]:
# Two RealisticTiles whose configs differ only in the programming routine:
# gradient-descent programming vs. the SVD method.
rpu_config = RPUConfigwithProgram(program_weights=program_weights_gdp)
ctile = RealisticTile(output_size, input_size, rpu_config=rpu_config)

rpu_config2 = RPUConfigwithProgram(program_weights=program_weights_svd)
ctile2 = RealisticTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# Inspect the full configuration of the GDP-programmed tile.
print(rpu_config)

In [None]:
# `set_weights` expects an [out_size, in_size] matrix, but `w` has shape
# (input_size, output_size); transpose it, as the AnalogTile cells do.
with LogCapture() as logc:
    ctile.set_weights(w.T, realistic=True)
    log_list = logc.get_log_list()

with LogCapture() as logc:
    ctile2.set_weights(w.T, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# Extract the logged weight-error norms and compare both programming methods
# (log scale, matching the AnalogTile comparison plot).
import matplotlib.pyplot as plt

err_list = extract_error(log_list)
err_list2 = extract_error(log_list2)

plt.semilogy(err_list, label="gdp")
plt.semilogy(err_list2, label="svd")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("L2 norm of weight error")
plt.title("Error vs Iteration")
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements the algorithm from "Gradient descent-based programming of analog in-memory computing cores".

The `set_weights` method sets the weights of the analog tile to the given values.\
When called with `realistic=True`, `set_weights` internally calls `program_weights` to program the weights of the analog tile.

The `get_weights` method returns the weights of the analog tile.\
The `read_weights` method reads the weights of the analog tile, including read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Convert a bias-free digital linear layer into an analog layer that uses the
# AnalogTile configuration `rpuconfig` defined earlier.
digital_layer = torch.nn.Linear(input_size, output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Toy training loop: with random inputs, minimize the sum of squared outputs
# (which pushes the analog layer's outputs toward zero), recording each loss.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
for _ in range(1000):
    x = torch.rand(input_size)
    yhat = layer(x)
    loss = (yhat**2).sum()
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# plot losses
import matplotlib.pyplot as plt

# Training-loss curve recorded by the loop above.
plt.plot(losses)