## Program Methods

### Gradient descent-based programming

In [None]:
from typing import Any, Optional, Union

import torch
from torch import Tensor

from src.utils.pylogger import RankedLogger

# Module-level logger used by the programming routines below to emit the
# per-iteration "Error: ..." lines.
# NOTE(review): `rank_zero_only=True` presumably limits output to the rank-0
# process in multi-process runs — confirm against `RankedLogger`.
log = RankedLogger(rank_zero_only=True)

# TODO: is this realistic?


@torch.no_grad()
def program_weights_gdp(
    self,
    from_reference: bool = True,
    x_values: Optional[Tensor] = None,
    learning_rate: float = 1,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
) -> None:
    """Program the target weights into the conductances using the
    pulse update defined.

    Programming is done using the defined tile-update (e.g. SGD)
    and matching inputs (`x_values` by default `eye`).

    Each iteration logs ``"Error: <value>"`` at debug level, where the value
    is the nuclear norm of (current weights - target weights). The stopping
    criterion, however, is the normalized mean absolute output deviation.

    Args:

        from_reference: Whether to use weights from reference
            (those that were initially set with `set_weights`) or
            the current weights.
        x_values: Values to use for the read-and-verify. If none
            are given, unit-vectors are used.
        learning_rate: Learning rate of the optimization.
        max_iter: Max number of batches for the iterative programming.
        tolerance: Stop the iteration loop early if the mean
            output deviation is below this number. Given in
            relation to the max output. ``None`` disables early stopping.
        w_init: Initial weight matrix to start from. If given as
            float, weights are set uniform random in `[-w_init,
            w_init]`. This init weight is given directly in
            normalized conductance units and should include the
            bias row if existing.
    """

    if not from_reference or self.reference_combined_weights is None:
        self.reference_combined_weights = self.tile.get_weights()
    # Must be bound on every path: previously this assignment lived inside the
    # `if` above, so using an existing reference raised UnboundLocalError.
    target_weights = self.reference_combined_weights

    if x_values is None:
        x_values = torch.eye(self.tile.get_x_size())
    # Device transfer and target computation apply to caller-supplied inputs
    # as well, not only to the default identity inputs.
    x_values = x_values.to(self.device)
    target_values = x_values @ target_weights.to(self.device).T

    target_max = target_values.abs().max().item()
    # An all-zero target would otherwise divide by zero below; in that case
    # the deviation is reported in absolute terms.
    err_scale = target_max if target_max > 0 else 1.0

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(learning_rate)  # type: ignore
    try:
        for _ in range(max_iter):
            y = self.tile.forward(x_values, False)
            # TODO: analyze the relation between this output error and the
            # norm of the weight error.
            error = y - target_values
            err_normalized = error.abs().mean().item() / err_scale
            mtx_diff = self.tile.get_weights() - target_weights
            nuc_norm = torch.linalg.matrix_norm(mtx_diff, ord="nuc")
            log.debug(f"Error: {nuc_norm}")
            if tolerance is not None and err_normalized < tolerance:
                break
            self.tile.update(x_values, error, False)  # type: ignore
    finally:
        # Restore the caller's learning rate even if an iteration fails.
        self.tile.set_learning_rate(lr_save)  # type: ignore

In [None]:
# NOTE(review): this rebinds the module-level `log` created in the first cell
# (there with `rank_zero_only=True`), now with default arguments. All functions
# in this notebook look up `log` at call time, so they use this instance once
# the cell has run — confirm this is intended.
log = RankedLogger()


@torch.no_grad()
def program_weights_gdp2(
    self,
    batch_size: int = 5,
    learning_rate: float = 1,
    max_iter: int = 100,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.01,
) -> None:
    """Program the target weights into the conductances using the
    pulse update defined.

    Variable batch version of the `program_weights_gdp` method: the current
    tile weights are taken as the target, the tile is re-initialized, and the
    rows of the identity matrix are then presented `batch_size` at a time,
    cycling through them as often as needed.

    Each iteration logs ``"Error: <value>"`` at debug level, where the value
    is the nuclear norm of (current weights - target weights); the same value
    is used as the stopping criterion.

    Args:
        batch_size: Number of unit-vector inputs used per update.
        learning_rate: Learning rate of the optimization.
        max_iter: Max number of batches for the iterative programming.
        tolerance: Stop early once the nuclear norm of the weight error is
            below this number. ``None`` disables early stopping.
        w_init: Initial weight matrix to start from. If given as float,
            weights are set uniform random in `[-w_init, w_init]`.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    target_weights = self.tile.get_weights()

    input_size = self.tile.get_x_size()
    x_values = torch.eye(input_size).to(self.device)
    target_values = x_values @ target_weights.to(self.device).T

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(learning_rate)  # type: ignore
    try:
        n_rows = len(x_values)
        for i in range(max_iter):
            # Cyclic window over the input rows. Taking the indices modulo
            # n_rows keeps every batch at exactly `batch_size` rows; the
            # previous slice-and-concatenate wrap-around used the raw
            # (unwrapped) offset, so after the first pass the batch kept
            # growing until it contained every row.
            rows = (
                torch.arange(
                    i * batch_size, (i + 1) * batch_size, device=x_values.device
                )
                % n_rows
            )
            x = x_values[rows]
            target = target_values[rows]

            y = self.tile.forward(x, False)
            error = y - target
            mtx_diff = self.tile.get_weights() - target_weights
            nuc_norm = torch.linalg.matrix_norm(mtx_diff, ord="nuc")
            log.debug(f"Error: {nuc_norm}")
            if tolerance is not None and nuc_norm < tolerance:
                break
            self.tile.update(x, error, False)  # type: ignore
    finally:
        # Restore the caller's learning rate even if an iteration fails.
        self.tile.set_learning_rate(lr_save)  # type: ignore

### Proposed method (SVD)

In [None]:
def compensate_half_selection(v: Tensor) -> Tensor:
    """Apply half-selection compensation to an update vector.

    This is currently a pass-through placeholder: no compensation is
    performed and the very same tensor object is handed back.

    Args:
        v: Update vector to compensate.

    Returns:
        The input `v`, unchanged.
    """
    compensated = v
    return compensated


@torch.no_grad()
def program_weights_svd(
    self,
    max_iter: int = 100,
    use_rank_as_criterion: bool = False,
    tolerance: Optional[float] = 0.01,
    w_init: Union[float, Tensor] = 0.0,
    rank_atol: Optional[float] = 1e-2,
    svd_once: bool = False,
    **kwargs: Any,
) -> None:
    """Program the tile weights with rank-1 updates derived from an SVD.

    The current tile weights are taken as the target. The tile is then
    re-initialized and the weight error ``diff = current - target`` is
    decomposed as ``U @ diag(S) @ Vh``. Each iteration applies one rank-1
    update built from a singular triplet, with ``sqrt(S[i])`` folded into
    both vectors. After every update ``"Error: <value>"`` is logged at debug
    level, where the value is the nuclear norm of the remaining weight error.

    Args:
        max_iter: Maximum number of rank-1 updates. Defaults to 100.
        use_rank_as_criterion: If True, additionally cap the number of updates
            at the numerical rank of the initial weight error. Defaults to
            False.
        tolerance: Stop early once the nuclear norm of the weight error is
            below this value; ``None`` disables early stopping. Defaults to
            0.01.
        w_init: Initial weights. A tensor is written as-is; a float draws
            weights uniformly from ``[-w_init, w_init]``. Defaults to 0.0.
        rank_atol: Absolute threshold above which a singular value counts
            towards the numerical rank. If None,
            ``S.max() * max(diff.shape) * eps`` is used. Defaults to 1e-2.
        svd_once: If True, decompose the initial error once and apply its
            singular triplets one after another (largest first). If False,
            re-decompose the remaining error after every update and always
            apply the leading triplet. Defaults to False.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        None
    """
    target_weights = self.tile.get_weights()

    if isinstance(w_init, Tensor):
        self.tile.set_weights(w_init)
    else:
        self.tile.set_weights_uniform_random(-w_init, w_init)  # type: ignore

    lr_save = self.tile.get_learning_rate()  # type: ignore
    self.tile.set_learning_rate(1)  # type: ignore
    try:
        # tile.update() applies w -= lr * delta_w, so the error is taken as
        # (current - target) for the update to move towards the target.
        diff = self.tile.get_weights() - target_weights
        U, S, Vh = torch.linalg.svd(diff.double(), full_matrices=False)
        if rank_atol is None:
            rank_atol = S.max() * max(diff.shape) * torch.finfo(S.dtype).eps
        rank = int(torch.sum(S > rank_atol).item())

        n_iter = min(max_iter, rank) if use_rank_as_criterion else max_iter
        if svd_once:
            # A single decomposition only provides S.numel() triplets;
            # walking past them used to raise IndexError.
            n_iter = min(n_iter, S.numel())

        for step in range(n_iter):
            # Single SVD: walk through its triplets. Repeated SVD: always
            # take the leading triplet of the latest decomposition.
            i = step if svd_once else 0
            sqrt_s = torch.sqrt(S[i])
            # Fresh tensors instead of in-place scaling, so the stored
            # factors U / Vh are never modified.
            u = U[:, i] * sqrt_s
            v = Vh[i, :] * sqrt_s
            u1, v1 = compensate_half_selection(u), compensate_half_selection(v)
            self.tile.update(v1.float(), u1.float(), False)

            # TODO: self.get_weights()
            diff = self.tile.get_weights() - target_weights
            if svd_once:
                # Keep the initial decomposition; only measure the error.
                nuc_norm = torch.linalg.matrix_norm(diff, ord="nuc")
            else:
                U, S, Vh = torch.linalg.svd(diff.double(), full_matrices=False)
                # Sum of singular values == nuclear norm; reuses the SVD
                # that was just computed.
                nuc_norm = S.sum()
            log.debug(f"Error: {nuc_norm}")
            if tolerance is not None and nuc_norm < tolerance:
                break
    finally:
        # Restore the caller's learning rate even if an iteration fails.
        self.tile.set_learning_rate(lr_save)  # type: ignore

## Experimental Results

### Util & setup

In [None]:
# From a list of log strings, find the "Error: <value>" pattern and collect the
# values into a list.
import re

from src.utils.logging_utils import LogCapture


def extract_error(log_list: list[str], prefix: str = "Error: ") -> list[float]:
    """Collect the numeric values that follow `prefix` in captured log lines.

    For every line, the first occurrence of `prefix` that is immediately
    followed by a decimal number (optional sign, optional fraction, optional
    exponent such as ``1.5e+20``) contributes one float. Lines without such
    an occurrence are skipped.

    Args:
        log_list: Log lines to scan.
        prefix: Literal text preceding the number. It is matched verbatim,
            so regex metacharacters in it need no escaping by the caller.

    Returns:
        The extracted values, in the order of the lines they came from.
    """
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    # Compiled once per call; re.escape keeps the prefix literal.
    pattern = re.compile(re.escape(prefix) + "(" + number + ")")

    matches = (pattern.search(line) for line in log_list)
    return [float(m.group(1)) for m in matches if m is not None]

In [None]:
from matplotlib import pyplot as plt


def plot_singular_values(Ws: tuple[torch.Tensor, ...]) -> None:
    """Plot the singular values of each matrix on a shared log-scale axis.

    One curve is drawn per matrix, in the order given, and the figure is
    shown immediately.

    Args:
        Ws: Matrices whose singular values are plotted.
    """
    for w in Ws:
        s = torch.linalg.svdvals(w)
        # Hand matplotlib a plain CPU array so tensors on other devices
        # (or attached to a graph) plot without error.
        plt.plot(s.detach().cpu().numpy())
    plt.yscale("log")
    plt.xlabel("Singular Value Index")
    plt.ylabel("Singular Value")
    plt.title("Singular Values of Weight Matrix")
    plt.show()

In [None]:
import torch

# Problem size shared by the experiments below.
input_size = 100
output_size = 50
batch_size = 1

# Target matrix: standard-normal entries clipped to [-1, 1], stored as float64
# with shape (input_size, output_size).
# With rank == min(w.shape) the truncation branch is skipped and w stays full
# rank; lower `rank` to obtain a genuinely low-rank target.
rank = 50
w = torch.randn(input_size, output_size).double()
w.clamp_(-1, 1)
if rank < min(w.shape):
    # Truncated SVD keeping the `rank` largest singular triplets.
    # torch.linalg.svd replaces the deprecated torch.svd and returns V^H
    # directly. The truncated matrix is not re-clamped to [-1, 1].
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    w = (u[:, :rank] * s[:rank]) @ vh[:rank, :]

### AnalogTile

In [None]:
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile
from aihwkit.simulator.tiles.analog import AnalogTile

# Build the RPU configuration and two analog tiles for the comparison
# (atile: gradient-descent programming, atile2: SVD programming).
# NOTE(review): `pre_post_cfg` is only referenced by the commented-out
# SingleRPUConfig line below; the active config does not use it.
pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=True))
# device_cfg = IdealDevice()
device_cfg = ConstantStepDevice()
update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC_COMPRESSED)

# rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)
# rpuconfig = FloatingPointRPUConfig()
rpuconfig = SingleRPUConfig(device=device_cfg, update=update_cfg)
# Forward pass configured as perfect with zero output noise, so tile reads in
# the experiments are presumably exact — confirm against the aihwkit IO
# parameter documentation.
rpuconfig.forward.is_perfect = True
rpuconfig.forward.out_noise = 0.0
# rpuconfig = IdealizedPreset()
atile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)  # without periphery
# Transfer atile's state into atile2 via a plain dict.
# NOTE(review): this assumes `state_dict(destination)` fills `atile_dic` in
# place and that loading it gives both tiles identical device state — confirm.
atile_dic = {}
atile.state_dict(atile_dic)
atile2 = AnalogTile(output_size, input_size, rpu_config=rpuconfig)
atile2.load_state_dict(atile_dic, assign=True)
print(rpuconfig)

In [None]:
# Display the attributes of the device configuration as the cell output.
rpuconfig.device.__dict__

In [None]:
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

# enroll the programming methods
# `func.__get__(obj, cls)` yields a bound method, so each tile instance gets
# its own `program_weights` with `self` set to that tile:
# atile -> batched gradient-descent programming, atile2 -> SVD programming.
atile.program_weights = program_weights_gdp2.__get__(atile, TileWithPeriphery)
atile2.program_weights = program_weights_svd.__get__(atile2, TileWithPeriphery)

In [None]:
# Run gradient-descent programming on atile and keep the emitted log lines.
# NOTE(review): LogCapture presumably records the `log.debug` output produced
# inside the block — confirm against src.utils.logging_utils.
with LogCapture() as logc:
    # Write the target (w transposed) directly; program_weights_gdp2 reads the
    # current tile weights as its target before re-initializing the tile.
    atile.tile.set_weights(w.clone().T)
    atile.program_weights(batch_size=batch_size, tolerance=1e-10, max_iter=1000)
    log_list1 = logc.get_log_list()

In [None]:
# Run SVD-based programming on atile2 and keep the emitted log lines.
with LogCapture() as logc:
    # atile2.set_weights(weight=w, realistic=True)
    # Write the target (w transposed) directly; program_weights_svd reads the
    # current tile weights as its target before re-initializing the tile.
    atile2.tile.set_weights(w.clone().T)
    # NOTE(review): `target_weights` is not a named parameter of
    # program_weights_svd; it ends up in **kwargs, which the function body
    # never reads, so it has no effect here.
    atile2.program_weights(
        max_iter=1000, tolerance=1e-10, svd_once=False, target_weights=w.clone().T
    )
    log_list2 = logc.get_log_list()

In [None]:
# Residual weight error (target - programmed weights) of each tile:
# W[0] for atile (GDP), W[1] for atile2 (SVD).
W = (w.T - atile.tile.get_weights(), w.T - atile2.tile.get_weights())
# plot_singular_values(W)

In [None]:
# Report the nuclear norm of each residual computed in the previous cell.
print(
    f" nulcear norm of atile: {torch.linalg.matrix_norm(W[0], ord='nuc')}, atile2: {torch.linalg.matrix_norm(W[1], ord='nuc')}"
)

In [None]:
# Display the last ten captured log lines of the SVD run as the cell output.
log_list2[-10:]

In [None]:
# Parse the per-iteration error values out of both captured logs and compare
# the two programming schemes on a log-scale y axis.
err_list1 = extract_error(log_list1)
err_list2 = extract_error(log_list2)
plt.semilogy(err_list1)
plt.semilogy(err_list2)
# Legend entries follow the plotting order above: GDP first, then SVD.
plt.legend([f"gdp-seq(batchsize {batch_size})", "svd"])
plt.xlabel("Iteration")
plt.ylabel("Nuclear norm of weight error (sum of singular values)")
plt.title("Error vs Iteration @ {}x{}, rank={}".format(input_size, output_size, rank))
plt.show()

#### GDP batch-size effect

In [None]:
# Sweep the GDP batch size and overlay the resulting error curves.
# Only `batch_size` is passed, so each run uses program_weights_gdp2's
# defaults for the remaining arguments (max_iter=100, tolerance=0.01, ...).
for batch_size_ in [1, 5, 10, 20, 50, input_size]:
    with LogCapture() as logc:
        # Reset the tile to the target before every run.
        atile.tile.set_weights(w.T)
        atile.program_weights(batch_size=batch_size_)
        log_list = logc.get_log_list()
    err_list = extract_error(log_list)
    # NOTE(review): `num_iter` is assigned but not used in this cell.
    num_iter = len(err_list)
    plt.semilogy(err_list, label=f"batch_size={batch_size_}")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("Nuclear norm of weight error")
# The title includes the device class name of the tile's RPU config.
plt.title(
    "{}x{} rank={} matrix with {}".format(
        input_size, output_size, rank, atile.rpu_config.device.__class__.__name__
    )
)

### Device-to-device (d2d) variation

In [None]:
# print dataclass fields
# Display the attributes of the tile's device configuration as the cell output.
atile.rpu_config.device.__dict__

In [None]:
# Display the top-left 5x5 corner of the transposed target matrix.
w.T[:5, :5]

In [None]:
# check whether the element wise perturbation is applied

# Write the target into the tile, read it back, and compare: a True result
# means the stored weights match the target within torch.allclose's default
# tolerances, i.e. no element-wise perturbation is visible on this path.
atile.tile.set_weights(w.T)
wtile = atile.tile.get_weights()
torch.allclose(wtile, w.T)

### CustomTile

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# Create a CustomTile with the default configuration and display the result
# of `get_weights(realistic=True)` as the cell output.
ctile = CustomTile(output_size, input_size)
ctile.get_weights(realistic=True)

### RealisticTile (Ours)

In [None]:
from src.prog_scheme.realistic import RealisticTile, RPUConfigwithProgram

# Two RealisticTile instances that differ only in the programming routine
# handed to their config: gradient descent (ctile) vs. SVD (ctile2).
# NOTE: `ctile` rebinds the CustomTile created in the CustomTile section.
rpu_config = RPUConfigwithProgram(program_weights=program_weights_gdp)
ctile = RealisticTile(output_size, input_size, rpu_config=rpu_config)

rpu_config2 = RPUConfigwithProgram(program_weights=program_weights_svd)
ctile2 = RealisticTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# Print the configuration that carries the gradient-descent programming routine.
print(rpu_config)

In [None]:
# Set the weights on both tiles with `realistic=True` and capture the log
# output of each call.
# NOTE(review): presumably `realistic=True` triggers the configured
# `program_weights` routine, which is what emits the "Error: ..." lines —
# confirm against RealisticTile.
# NOTE: `log_list2` rebinds the list captured in the AnalogTile section.
with LogCapture() as logc:
    ctile.set_weights(w, realistic=True)
    log_list = logc.get_log_list()

with LogCapture() as logc:
    ctile2.set_weights(w, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# extract error and plot
import matplotlib.pyplot as plt

# Parse the logged error values of both RealisticTile runs and plot them on a
# linear axis.
err_list = extract_error(log_list)
err_list2 = extract_error(log_list2)

# NOTE(review): the first label reads "gpc" although the routine behind it is
# program_weights_gdp — possibly a typo; the string is left unchanged.
plt.plot(err_list, label="gpc")
plt.plot(err_list2, label="svd")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("Error")
plt.title("Error vs Iteration")
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements "Gradient descent-based programming of analog in-memory computing cores".

The `set_weights` method sets the weights of the analog tile to the given values.\
The `program_weights` method is called internally by `set_weights` to program the weights of the analog tile.

The `get_weights` method returns the weights of the analog tile.\
The `read_weights` method reads the weights of the analog tile with read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Bias-free digital linear layer, converted to an analog layer with the
# `rpuconfig` built in the AnalogTile section.
digital_layer = torch.nn.Linear(input_size, output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Train the analog layer for 1000 steps on random inputs and record the loss.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
for _ in range(1000):
    # A fresh input drawn uniformly from [0, 1) at every step.
    x = torch.rand(input_size)
    yhat = layer(x)
    # The loss is the sum of squared outputs, so minimizing it pushes the
    # layer output towards zero.
    loss = (yhat**2).sum()
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# plot losses
import matplotlib.pyplot as plt

# Loss per training step recorded by the loop in the previous cell.
plt.plot(losses)