# Custom Tiles

## Primitive update method

In [None]:
import torch
from aihwkit.simulator.tiles.custom import CustomSimulatorTile
from torch import Tensor

# TODO: implement the methods below,
# integrating them with aihwkit's PulseType values (aihwkit.simulator.parameters.enums.PulseType).
# Also try adding a PCM inference noise model.

# Note: this could also be implemented in AnalogMVM.matmul() instead of SimulatorTile.forward().
# AnalogMVM.matmul()
# -> SimulatorTile.forward()
# -> PeripheryTile.joint_forward()
# -> AnalogFunction.forward()
# -> CustomTile.forward()
# -> PeripheryTile.read_weights()/program_weights()
# -> used by PeripheryTile.get_weights()/set_weights(realistic=True)


class _HalfSelectMixin:
    """Implements the half-selected update method."""

    def half_selection(self): ...


class HalfSelectedSimulatorTile(_HalfSelectMixin, CustomSimulatorTile):
    """Custom simulator tile that runs the ``half_selection`` hook after each update.

    Weight storage, the forward pass and the base update are inherited from
    ``CustomSimulatorTile``. This class only appends the hook provided by
    ``_HalfSelectMixin`` to ``update``. ``set_weights`` / ``get_weights`` are
    explicit override points for a future device-specific implementation; until
    then they delegate to the parent, so the tile stays usable by the
    weight-programming routines (which call ``tile.get_weights()``).
    """

    def set_weights(self, weight: Tensor) -> None:
        """Set the tile weights.

        Currently delegates to ``CustomSimulatorTile.set_weights``.

        Args:
            weight: ``[out_size, in_size]`` weight matrix.
        """
        super().set_weights(weight)

    def get_weights(self) -> Tensor:
        """Get the tile weights.

        Currently delegates to ``CustomSimulatorTile.get_weights``.

        Returns:
            The ``[out_size, in_size]`` weight matrix as returned by the parent
            simulator tile.
        """
        return super().get_weights()

    def update(
        self,
        x_input: Tensor,
        d_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Run the inherited rank-1 tile update, then apply ``half_selection``.

        Note:
            Additional arguments are forwarded to the parent update unchanged.

        Returns:
            Whatever the parent ``CustomSimulatorTile.update`` returns.

        Raises:
            TileError: (from the parent update, per its docstring) in case
                transposed input / output or bias is requested.
        """
        result = super().update(x_input, d_input, bias, in_trans, out_trans, non_blocking)
        # The hook runs after the ideal update so it can perturb the freshly written weights.
        self.half_selection()
        return result

    def forward(
        self,
        x_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        is_test: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Run the inherited forward pass unchanged (kept as an override point)."""
        return super().forward(x_input, bias, in_trans, out_trans, is_test, non_blocking)

## CustomTile & RPUConfig

In [None]:
from dataclasses import dataclass
from typing import Any, Callable, Optional, Type

from aihwkit.simulator.tiles.custom import CustomRPUConfig, CustomTile


class MyCustomTile(CustomTile):
    """``CustomTile`` whose weight-programming routine can come from its RPU config.

    If ``rpu_config.program_weights`` is set, that function is bound to the new
    tile instance as ``program_weights``, so it receives the tile as ``self``.
    Otherwise the tile keeps whatever ``program_weights`` it inherits (if any).
    """

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: Optional["RPUConfigwithProgram"],
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
    ):
        super().__init__(out_size, in_size, rpu_config, bias, in_trans, out_trans)
        # rpu_config may be None (per the annotation) or lack the field, and the
        # field itself defaults to None; only bind when a callable was supplied.
        program = getattr(rpu_config, "program_weights", None)
        if program is not None:
            # Dynamically bind the plain function to this instance; the instance
            # attribute shadows the class-level method.
            self.program_weights = program.__get__(self, type(self))


# TODO: attach the program_weights method directly to the dataclass?
@dataclass
class RPUConfigwithProgram(CustomRPUConfig):
    """Custom single RPU configuration that also carries a weight-programming routine.

    ``program_weights`` is an optional plain function that takes the tile as its
    first argument (like ``program_weights_gpc`` / ``program_weights_svd``).
    ``MyCustomTile`` binds it to the tile it builds.
    """

    program_weights: Optional[Callable[..., None]] = None
    """Function ``(tile, **options) -> None`` used to program the weights, or ``None``."""

    tile_class: Type = MyCustomTile
    """Tile class that corresponds to this RPUConfig."""

    simulator_tile_class: Type = HalfSelectedSimulatorTile
    """Simulator tile class implementing the analog forward / backward / update."""

## Program Methods

### Gradient-based programming

In [None]:
from typing import Optional, Union

from torch import Tensor

from src.utils.pylogger import RankedLogger

# Module-level logger. program_weights_gpc / program_weights_svd emit their
# per-iteration "Error: ..." debug lines through it; the experiment cells are
# meant to capture those lines and parse them with extract_error().
log = RankedLogger()

# TODO: is this realistic?


@torch.no_grad()
def program_weights_gpc(
    self,
    from_reference: bool = True,
    x_values: Optional[Tensor] = None,
    learning_rate: float = 0.1,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
) -> None:
    """Program the target weights into the conductances using the tile's
    pulse update (gradient-descent-based programming).

    Intended to be bound as a method of a tile with periphery: it reads
    ``self.tile``, ``self.device`` and ``self.reference_combined_weights``.
    The tile update (e.g. SGD) is applied repeatedly with the inputs
    ``x_values`` and the output error
    ``tile.forward(x_values) - x_values @ target.T``. Each iteration logs
    ``"Error: <value>"`` at debug level.

    Args:
        from_reference: Whether to use the stored reference weights
            (``self.reference_combined_weights``, those initially set with
            ``set_weights``) as the target. If ``False``, or if no reference
            is stored, the tile's current weights are read and stored as the
            new reference first.
        x_values: Inputs used for the read-and-verify step. If ``None``, the
            identity matrix of size ``tile.get_x_size()`` (unit vectors) is
            used. Given values are moved to ``self.device``.
        learning_rate: Learning rate used by the tile during programming.
            The tile's previous learning rate is restored afterwards, also
            when an exception is raised.
        max_iter: Maximum number of update iterations.
        tolerance: Stop early once the mean absolute output error, relative
            to the largest absolute target output, drops below this value.
            ``None`` disables early stopping.
        w_init: Initial weights. A tensor is written as-is (given directly in
            normalized conductance units, including the bias row if
            existing); a float ``a`` sets weights uniformly at random in
            ``[-a, a]``.
    """
    if not from_reference or self.reference_combined_weights is None:
        self.reference_combined_weights = self.tile.get_weights()
    # Must be assigned on every path, not only when the reference is refreshed.
    target_weights = self.reference_combined_weights

    if x_values is None:
        x_values = torch.eye(self.tile.get_x_size())
    # Device move and target computation apply to user-supplied inputs too.
    x_values = x_values.to(self.device)
    target_values = x_values @ target_weights.to(self.device).T

    target_max = target_values.abs().max().item()
    if target_max == 0.0:
        # All-zero target: avoid dividing by zero; the error becomes absolute.
        target_max = 1.0

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(learning_rate)  # type: ignore
    try:
        for _ in range(max_iter):
            y = self.tile.forward(x_values, False)
            # TODO: analyze the relation between this error and the weight 2-norm
            error = y - target_values
            err_normalized = error.abs().mean().item() / target_max
            log.debug(f"Error: {err_normalized}")
            if tolerance is not None and err_normalized < tolerance:
                break
            self.tile.update(x_values, error, False)  # type: ignore
    finally:
        self.tile.set_learning_rate(lr_save)  # type: ignore

### Proposed method (SVD)

In [None]:
def compensate_half_selection(v: Tensor) -> Tensor:
    """Compensate the half-selection effect for a vector.

    Placeholder: no compensation model is implemented yet, so the vector is
    returned unchanged. Returning the input (rather than ``None``) keeps
    callers such as ``program_weights_svd`` usable in the meantime.

    Args:
        v: Vector to compensate.

    Returns:
        Compensated vector (currently ``v`` itself).
    """
    # TODO: implement the actual half-selection compensation.
    return v


@torch.no_grad()
def program_weights_svd(
    self,
    from_reference: bool = True,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
    rank_atol: float = 1e-6,
    rank_rtol: float = 1e-6,
    svd_once: bool = False,
    **kwargs: Any,
) -> None:
    """Program the target weights with rank-1 updates taken from an SVD.

    Intended to be bound as a method of a tile with periphery: it reads
    ``self.tile`` and ``self.reference_combined_weights`` and calls
    ``self.get_weights(realistic=True)``. The residual ``target - current``
    is decomposed with an SVD and programmed one singular triplet at a
    time as the outer product ``(sqrt(s) * u) (sqrt(s) * v)^T``, using a
    tile learning rate of 1. After each update the residual is re-read and
    its Frobenius norm is logged as ``"Error: <value>"`` at debug level.

    Args:
        from_reference: Whether to use the stored reference weights as the
            target. If ``False``, or if no reference is stored, the tile's
            current weights are read and stored as the new reference first.
        max_iter: Maximum number of rank-1 updates; additionally capped by
            the numerical rank of the initial residual.
        tolerance: Stop early once the Frobenius norm of the residual
            (absolute, not normalized) drops below this value. ``None``
            disables early stopping.
        w_init: Initial weights. A tensor is written as-is; a float ``a``
            sets weights uniformly at random in ``[-a, a]``.
        rank_atol: Absolute tolerance for the numerical rank.
        rank_rtol: Relative tolerance (w.r.t. the largest singular value)
            for the numerical rank.
        svd_once: If ``True``, reuse the initial SVD and step through its
            singular triplets. Otherwise, recompute the SVD of the re-read
            residual after every update and program its leading triplet.
        **kwargs: Ignored; accepted so callers may pass extra options.
    """
    if not from_reference or self.reference_combined_weights is None:
        self.reference_combined_weights = self.tile.get_weights()
    # Must be assigned on every path, not only when the reference is refreshed.
    target_weights = self.reference_combined_weights

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    # Unit learning rate so a single update writes exactly the requested outer product.
    self.tile.set_learning_rate(1)  # type: ignore
    try:
        U, S, Vh = torch.linalg.svd(target_weights - self.tile.get_weights(), full_matrices=False)
        # Numerical rank with the torch.linalg.matrix_rank rule: count singular
        # values above max(atol, rtol * largest singular value).
        s_max = S[0].item() if S.numel() > 0 else 0.0
        threshold = max(rank_atol, rank_rtol * s_max)
        num_rank = int((S > threshold).sum().item())

        i = 0
        for _ in range(min(num_rank, max_iter)):
            # Split sqrt(S[i]) over both factors: outer(u1, v1) == S[i] * outer(U[:, i], Vh[i]).
            scale = torch.sqrt(S[i])
            u1 = U[:, i] * scale  # left singular vector = column of U (output side)
            v1 = Vh[i] * scale  # right singular vector = row of Vh (input side)
            u1, v1 = compensate_half_selection(u1), compensate_half_selection(v1)
            # Same convention program_weights_gpc relies on: update(x, d) moves the
            # weights by -lr * d^T x, so x = v1 and d = -u1 add outer(u1, v1).
            # A leading batch dimension of 1 makes d^T x an outer product.
            self.tile.update(v1.unsqueeze(0), (-u1).unsqueeze(0), False)  # type: ignore
            # get_weights() returns (weight, bias); assumes no bias column so the
            # weight part matches the reference shape — TODO confirm for bias=True.
            diff = target_weights - self.get_weights(realistic=True)[0]
            l2_norm = diff.norm().item()
            log.debug(f"Error: {l2_norm}")
            if tolerance is not None and l2_norm < tolerance:
                break
            if svd_once:
                i += 1
            else:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
    finally:
        self.tile.set_learning_rate(lr_save)  # type: ignore

## Experimental Results

### AnalogTile()

In [None]:
from aihwkit.simulator.configs import SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice, DriftParameter
from aihwkit.simulator.configs.utils import InputRangeParameter, PrePostProcessingParameter
from aihwkit.simulator.tiles.analog import AnalogTile

# Small analog tile with 3 outputs x 6 inputs for the programming experiments.
# These globals (sizes, rpuconfig, atile) are reused by later cells.
input_size = 6
output_size = 3
# Pre/post-processing with the input-range parameters enabled.
pre_post = PrePostProcessingParameter(input_range=InputRangeParameter(enable=True))
# Constant-step device with diffusion set to 0 and default drift parameters.
device = ConstantStepDevice(diffusion=0, drift=DriftParameter())
rpuconfig = SingleRPUConfig(device=device, pre_post=pre_post)
atile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)  # with periphery
print(rpuconfig)

In [None]:
# Bind the local GPC implementation to this tile instance. The instance attribute
# shadows AnalogTile's own program_weights, which set_weights(realistic=True) is
# expected to call internally (presumably via self.program_weights — confirm).
atile.program_weights = program_weights_gpc.__get__(atile, AnalogTile)

In [None]:
import torch

from src.utils.logging_utils import LogCapture

# Random target weights in [0, 1), shaped like the current weight matrix
# (element 0 of what get_weights() returns).
w = torch.rand_like(atile.get_weights()[0])

# Program the target with realistic=True. LogCapture presumably collects the log
# lines emitted inside the block; the "Error: ..." lines are parsed by extract_error.
with LogCapture() as logc:
    atile.set_weights(weight=w, realistic=True)
    log_list = logc.get_log_list()

In [None]:
import torch

from src.utils.logging_utils import LogCapture

# NOTE(review): this cell repeats the previous one verbatim. It draws a fresh
# random target and overwrites `w` and `log_list`; with a float w_init the GPC
# routine re-initialises the weights randomly, so this is an independent re-run.
w = torch.rand_like(atile.get_weights()[0])

with LogCapture() as logc:
    atile.set_weights(weight=w, realistic=True)
    log_list = logc.get_log_list()

In [None]:
# From the log list (list[str]), find the "Error: {err_normalized}" pattern and
# extract the values into a list.
import re


def extract_error(log_list):
    """Parse the numeric values of ``Error: <value>`` log lines.

    Accepts plain floats, including scientific notation (``1.5e-05``), ``nan``
    and ``inf``, as well as values logged as a tensor repr
    (``Error: tensor(0.25)``). Lines without such a value are skipped rather
    than raising.

    Args:
        log_list: Iterable of log-message strings.

    Returns:
        List of floats in the order the matching lines appear.
    """
    # \b keeps names like "ValueError:" from matching; the optional "tensor("
    # handles values that were logged as a tensor instead of a float.
    pattern = re.compile(
        r"\bError:\s*(?:tensor\()?\s*"
        r"([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))"
    )
    err_list = []
    for line in log_list:
        match = pattern.search(line)
        if match is not None:
            err_list.append(float(match.group(1)))
    return err_list

In [None]:
import matplotlib.pyplot as plt

err_list = extract_error(log_list)

# Object-oriented matplotlib API: one axes showing the error per programming iteration.
fig, ax = plt.subplots()
ax.plot(err_list)
ax.set_xlabel("Iteration")
ax.set_ylabel("Error")
ax.set_title("Error vs Iteration")
plt.show()

### GPC vs SVD

In [None]:
# Two custom tiles of the same size that differ only in the programming routine
# plugged in through the RPU config: gradient-based (GPC) vs. SVD-based.
rpu_config = RPUConfigwithProgram(program_weights=program_weights_gpc)
ctile = MyCustomTile(output_size, input_size, rpu_config=rpu_config)

rpu_config2 = RPUConfigwithProgram(program_weights=program_weights_svd)
ctile2 = MyCustomTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# Show the GPC tile's configuration (rpu_config2 differs only in program_weights).
print(rpu_config)

In [None]:
# Both tiles get the same random target in [0, 1) so the two programming methods
# are compared on identical weights. set_weights needs the weight argument, and
# realistic=True is what routes it through program_weights (same call pattern as
# the AnalogTile experiment).
w_target = torch.rand_like(ctile.get_weights()[0])

with LogCapture() as logc:
    ctile.set_weights(weight=w_target, realistic=True)
    log_list = logc.get_log_list()

with LogCapture() as logc:
    ctile2.set_weights(weight=w_target, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# extract error and plot
import matplotlib.pyplot as plt

err_list = extract_error(log_list)
err_list2 = extract_error(log_list2)

# NOTE(review): the curves are not on the same scale. program_weights_gpc logs a
# mean absolute output error normalized by the largest target output, while
# program_weights_svd logs the norm of the weight residual.
fig, ax = plt.subplots()
for errors, name in ((err_list, "gpc"), (err_list2, "svd")):
    ax.plot(errors, label=name)
ax.legend()
ax.set_xlabel("Iteration")
ax.set_ylabel("Error")
ax.set_title("Error vs Iteration")
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements "Gradient descent-based programming of analog in-memory computing cores".

`set_weights` sets the weights of the analog tile to the given values.\
`program_weights` is called internally by `set_weights` to program the weights of the analog tile.

`get_weights` returns the weights of the analog tile.\
`read_weights` reads the weights of the analog tile, including read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Bias-free digital linear layer (input_size -> output_size) converted into an
# AnalogLinear that uses the SingleRPUConfig `rpuconfig` built for the AnalogTile.
digital_layer = torch.nn.Linear(input_size, output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Toy training loop: AnalogSGD minimizes the sum of squared outputs for random
# inputs drawn from [0, 1); the loss of every step is recorded in `losses`.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
for _ in range(1000):
    x = torch.rand(input_size)  # one unbatched input vector of length input_size
    yhat = layer(x)
    loss = (yhat**2).sum()
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# plot losses
import matplotlib.pyplot as plt

# Training loss per optimizer step, labelled and shown like the other plot cells.
plt.plot(losses)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Training loss")
plt.show()